fmeval.data_loaders.util

  1import logging
  2import os
  3import boto3
  4import botocore.errorfactory
  5import ray.data
  6import urllib.parse
  7
  8from typing import Type, Optional
  9from fmeval import util
 10from fmeval.constants import (
 11    MIME_TYPE_JSON,
 12    MIME_TYPE_JSONLINES,
 13    PARTITION_MULTIPLIER,
 14    SEED,
 15    MAX_ROWS_TO_TAKE,
 16)
 17from fmeval.data_loaders.data_sources import DataSource, LocalDataFile, S3DataFile, DataFile, S3Uri, get_s3_client
 18from fmeval.data_loaders.json_data_loader import JsonDataLoaderConfig, JsonDataLoader
 19from fmeval.data_loaders.json_parser import JsonParser
 20from fmeval.data_loaders.data_config import DataConfig
 21from fmeval.util import get_num_actors
 22from fmeval.exceptions import EvalAlgorithmClientError, EvalAlgorithmInternalError
 23from fmeval.perf_util import timed_block
 24
 25client = boto3.client("s3")
 26logger = logging.getLogger(__name__)
 27
 28
 29def get_dataset(config: DataConfig, num_records: Optional[int] = None) -> ray.data.Dataset:
 30    """
 31    Util method to load Ray datasets using an input DataConfig.
 32
 33    :param config: Input DataConfig
 34    :param num_records: the number of records to sample from the dataset
 35    """
 36    # The following setup is necessary to instruct Ray to preserve the
 37    # order of records in the datasets
 38    ctx = ray.data.DataContext.get_current()
 39    ctx.execution_options.preserve_order = True
 40    with timed_block(f"Loading dataset {config.dataset_name}", logger):
 41        data_source = get_data_source(config.dataset_uri)
 42        data_loader_config = _get_data_loader_config(data_source, config)
 43        data_loader = _get_data_loader(config.dataset_mime_type)
 44        data = data_loader.load_dataset(data_loader_config)
 45        count = data.count()
 46        util.require(count > 0, "Data has to have at least one record")
 47        if num_records and num_records > 0:  # pragma: no branch
 48            # TODO update sampling logic - current logic is biased towards first MAX_ROWS_TO_TAKE rows
 49            num_records = min(num_records, count)
 50            # We are using to_pandas, sampling with Pandas dataframe, and then converting back to Ray Dataset to use
 51            # Pandas DataFrame's ability to sample deterministically. This is temporary workaround till Ray solves this
 52            # issue: https://github.com/ray-project/ray/issues/40406
 53            if count > MAX_ROWS_TO_TAKE:
 54                # If count is larger than 100000, we take the first 100000 row, and then sample from that to
 55                # maintain deterministic behaviour. We are using take_batch to get a pandas dataframe of size
 56                # MAX_ROWS_TO_TAKE when the size of original dataset is greater than MAX_ROWS_TO_TAKE. This is to avoid
 57                # failures in driver node by pulling too much data.
 58                pandas_df = data.take_batch(batch_size=MAX_ROWS_TO_TAKE, batch_format="pandas")
 59            else:
 60                pandas_df = data.to_pandas()
 61            sampled_df = pandas_df.sample(num_records, random_state=SEED)
 62            data = ray.data.from_pandas(sampled_df)
 63        data = data.repartition(get_num_actors() * PARTITION_MULTIPLIER).materialize()
 64    return data
 65
 66
 67def _get_data_loader_config(data_source: DataSource, config: DataConfig) -> JsonDataLoaderConfig:
 68    """
 69    Returns a dataloader config based on the dataset MIME type specified in `config`.
 70
 71    :param data_source: The dataset's DataSource object.
 72    :param config: Configures the returned dataloader config.
 73    :returns: A dataloader config object, created from `data_source` and `config`.
 74    """
 75    if config.dataset_mime_type == MIME_TYPE_JSON:
 76        if not isinstance(data_source, DataFile):
 77            raise EvalAlgorithmInternalError(
 78                f"JSON datasets must be stored in a single file. " f"Provided dataset has type {type(data_source)}."
 79            )
 80        return JsonDataLoaderConfig(
 81            parser=JsonParser(config),
 82            data_file=data_source,
 83            dataset_mime_type=MIME_TYPE_JSON,
 84            dataset_name=config.dataset_name,
 85        )
 86    elif config.dataset_mime_type == MIME_TYPE_JSONLINES:
 87        if not isinstance(data_source, DataFile):
 88            raise EvalAlgorithmInternalError(
 89                f"JSONLines datasets must be stored in a single file. "
 90                f"Provided dataset has type {type(data_source)}."
 91            )
 92        return JsonDataLoaderConfig(
 93            parser=JsonParser(config),
 94            data_file=data_source,
 95            dataset_mime_type=MIME_TYPE_JSONLINES,
 96            dataset_name=config.dataset_name,
 97        )
 98    else:  # pragma: no cover
 99        raise EvalAlgorithmInternalError(
100            "Dataset MIME types other than JSON and JSON Lines are not supported. "
101            f"MIME type detected from config is {config.dataset_mime_type}."
102        )
103
104
def _get_data_loader(dataset_mime_type: str) -> Type[JsonDataLoader]:
    """
    Map a dataset MIME type to the dataloader class that reads it.

    JSON and JSON Lines are both handled by JsonDataLoader.

    :param dataset_mime_type: Determines which dataloader class to return.
    :returns: A dataloader class.
    :raises EvalAlgorithmInternalError: for any other MIME type.
    """
    supported_types = (MIME_TYPE_JSON, MIME_TYPE_JSONLINES)
    if dataset_mime_type not in supported_types:  # pragma: no cover
        raise EvalAlgorithmInternalError(
            "Dataset MIME types other than JSON and JSON Lines are not supported. "
            f"MIME type detected from config is {dataset_mime_type}."
        )
    return JsonDataLoader
121
122
def get_data_source(dataset_uri: str) -> DataSource:
    """
    Validate a dataset URI and wrap it in the matching DataSource.

    A local path (or ``file://`` URI) is tried first, then an S3 URI.

    :param dataset_uri: local dataset path or s3 dataset uri
    :return: DataSource object
    :raises EvalAlgorithmClientError: if the URI is neither an existing local path
        nor a reachable S3 object.
    """
    if _is_valid_local_path(dataset_uri):
        return _get_local_data_source(dataset_uri)
    if _is_valid_s3_uri(dataset_uri):
        return _get_s3_data_source(dataset_uri)
    raise EvalAlgorithmClientError(f"Invalid dataset path: {dataset_uri}")
135
136
def _get_local_data_source(dataset_uri) -> LocalDataFile:
    """
    Resolve a local path or ``file://`` URI to a LocalDataFile with an absolute path.

    :param dataset_uri: local dataset path
    :return: LocalDataFile object with dataset uri
    :raises EvalAlgorithmClientError: if the path is a directory or is not a regular file.
    """
    resolved_path = os.path.abspath(urllib.parse.urlparse(dataset_uri).path)
    if os.path.isdir(resolved_path):
        # TODO: extend support to directories
        raise EvalAlgorithmClientError("Please provide a local file path instead of a directory path.")
    if not os.path.isfile(resolved_path):
        raise EvalAlgorithmClientError(f"Invalid local path: {dataset_uri}")
    return LocalDataFile(resolved_path)
149
150
def _get_s3_data_source(dataset_uri) -> S3DataFile:
    """
    Wrap an S3 object URI in an S3DataFile, rejecting S3 "directory" objects.

    :param dataset_uri: s3 dataset uri
    :return: S3DataFile object with dataset uri
    :raises EvalAlgorithmClientError: if the object's ContentType marks it as a directory.
    """
    s3_client = get_s3_client(dataset_uri)
    s3_uri = S3Uri(dataset_uri)
    # Only the object's metadata (ContentType) is needed. HEAD fetches just that.
    # GET would open a streaming body that this function never reads or closes.
    s3_obj_metadata = s3_client.head_object(Bucket=s3_uri.bucket, Key=s3_uri.key)
    if "application/x-directory" in s3_obj_metadata.get("ContentType", ""):
        # TODO: extend support to directories
        raise EvalAlgorithmClientError("Please provide a s3 file path instead of a directory path.")
    # There isn't a good way to check if the object corresponds specifically to a file,
    # so we treat anything that is not a directory as a file.
    return S3DataFile(dataset_uri)
166
167
def _is_valid_s3_uri(uri: str) -> bool:
    """
    Check whether `uri` uses an S3 scheme and points at an object that can be reached.

    :param uri: s3 file path
    :return: True if uri is a valid s3 path, False otherwise
    """
    parsed_url = urllib.parse.urlparse(uri)
    if parsed_url.scheme.lower() not in ["s3", "s3n", "s3a"]:
        return False
    try:
        s3_client = get_s3_client(uri)
        s3_uri = S3Uri(uri)
        # HEAD is enough to confirm the object exists and is accessible. The previous
        # GET opened a response body stream that was never read or closed.
        s3_client.head_object(Bucket=s3_uri.bucket, Key=s3_uri.key)
        return True
    except botocore.errorfactory.ClientError:
        # Missing object, missing bucket, or access denied.
        return False
183
184
185def _is_valid_local_path(path: str) -> bool:
186    """
187    :param path: local file path
188    :return: True if path is a valid local path, False otherwise
189    """
190    parsed_url = urllib.parse.urlparse(path)
191    return parsed_url.scheme in ["", "file"] and os.path.exists(parsed_url.path)
client = <botocore.client.S3 object>
logger = <Logger fmeval.data_loaders.util (INFO)>
def get_dataset( config: fmeval.data_loaders.data_config.DataConfig, num_records: Optional[int] = None) -> ray.data.dataset.Dataset:
def get_dataset(config: DataConfig, num_records: Optional[int] = None) -> ray.data.Dataset:
    """
    Util method to load Ray datasets using an input DataConfig.

    The dataset at ``config.dataset_uri`` is loaded with the loader matching
    ``config.dataset_mime_type``. It is optionally down-sampled (deterministically,
    using ``SEED``) to ``num_records`` rows, then repartitioned and materialized.

    :param config: Input DataConfig
    :param num_records: the number of records to sample from the dataset. When None or
        non-positive, every record is kept.
    :returns: The loaded (and possibly sampled) Ray Dataset, materialized.
    :raises EvalAlgorithmClientError: if ``config.dataset_uri`` is not a valid local or S3 path.
        ``util.require`` also fails if the dataset has no records.
    """
    # The following setup is necessary to instruct Ray to preserve the
    # order of records in the datasets
    ctx = ray.data.DataContext.get_current()
    ctx.execution_options.preserve_order = True
    with timed_block(f"Loading dataset {config.dataset_name}", logger):
        data_source = get_data_source(config.dataset_uri)
        data_loader_config = _get_data_loader_config(data_source, config)
        data_loader = _get_data_loader(config.dataset_mime_type)
        data = data_loader.load_dataset(data_loader_config)
        count = data.count()
        util.require(count > 0, "Data has to have at least one record")
        if num_records and num_records > 0:  # pragma: no branch
            # TODO update sampling logic - current logic is biased towards first MAX_ROWS_TO_TAKE rows
            # We are using to_pandas, sampling with Pandas dataframe, and then converting back to Ray Dataset to use
            # Pandas DataFrame's ability to sample deterministically. This is temporary workaround till Ray solves this
            # issue: https://github.com/ray-project/ray/issues/40406
            if count > MAX_ROWS_TO_TAKE:
                # If count is larger than MAX_ROWS_TO_TAKE, we take the first MAX_ROWS_TO_TAKE rows and sample
                # from those to keep sampling deterministic. take_batch keeps the driver node from pulling in
                # the whole dataset.
                pandas_df = data.take_batch(batch_size=MAX_ROWS_TO_TAKE, batch_format="pandas")
            else:
                pandas_df = data.to_pandas()
            # Cap the sample size by the rows actually pulled into pandas, not by `count`.
            # When count > MAX_ROWS_TO_TAKE only MAX_ROWS_TO_TAKE rows are available, and
            # DataFrame.sample(n) (no replacement) raises ValueError if n exceeds that.
            num_records = min(num_records, len(pandas_df))
            sampled_df = pandas_df.sample(num_records, random_state=SEED)
            data = ray.data.from_pandas(sampled_df)
        data = data.repartition(get_num_actors() * PARTITION_MULTIPLIER).materialize()
    return data

Util method to load Ray datasets using an input DataConfig.

Parameters
  • config: Input DataConfig
  • num_records: the number of records to sample from the dataset; if omitted or non-positive, all records are kept
def get_data_source(dataset_uri: str) -> fmeval.data_loaders.data_sources.DataSource:
def get_data_source(dataset_uri: str) -> DataSource:
    """
    Validate a dataset URI and wrap it in the matching DataSource.

    A local path (or ``file://`` URI) is tried first, then an S3 URI.

    :param dataset_uri: local dataset path or s3 dataset uri
    :return: DataSource object
    :raises EvalAlgorithmClientError: if the URI is neither an existing local path
        nor a reachable S3 object.
    """
    if _is_valid_local_path(dataset_uri):
        return _get_local_data_source(dataset_uri)
    if _is_valid_s3_uri(dataset_uri):
        return _get_s3_data_source(dataset_uri)
    raise EvalAlgorithmClientError(f"Invalid dataset path: {dataset_uri}")

Validates a dataset URI and returns the corresponding DataSource object

Parameters
  • dataset_uri: local dataset path or s3 dataset uri
Returns

DataSource object