fmeval.eval_algorithms.toxicity

  1import logging
  2from dataclasses import dataclass
  3from typing import Optional, List, Union, Dict
  4
  5import ray
  6from ray.actor import ActorHandle
  7import numpy as np
  8from fmeval.constants import DatasetColumns, MEAN
  9from fmeval.data_loaders.data_config import DataConfig
 10from fmeval.data_loaders.util import get_dataset
 11from fmeval.eval_algorithms import (
 12    EvalOutput,
 13    EvalScore,
 14    EvalAlgorithm,
 15)
 16from fmeval.eval_algorithms.common import evaluate_dataset
 17from fmeval.eval_algorithms.eval_algorithm import EvalAlgorithmInterface, EvalAlgorithmConfig
 18from fmeval.eval_algorithms.helper_models.helper_model import ToxigenHelperModel, DetoxifyHelperModel, BaseHelperModel
 19from fmeval.eval_algorithms.save_strategy import SaveStrategy
 20from fmeval.eval_algorithms.util import (
 21    get_dataset_configs,
 22)
 23from fmeval.transforms.batched_transform import BatchedTransform
 24from fmeval.transforms.transform_pipeline import TransformPipeline
 25from fmeval.transforms.util import validate_call
 26from fmeval.util import get_eval_results_path, create_shared_resource, cleanup_shared_resource
 27from fmeval.exceptions import EvalAlgorithmClientError
 28from fmeval.model_runners.model_runner import ModelRunner
 29
 30TOXIGEN_MODEL = "toxigen"
 31DETOXIFY_MODEL = "detoxify"
 32DEFAULT_MODEL_TYPE = DETOXIFY_MODEL
 33MODEL_TYPES_SUPPORTED = [TOXIGEN_MODEL, DETOXIFY_MODEL]
 34TOXICITY_HELPER_MODEL_MAPPING = {TOXIGEN_MODEL: ToxigenHelperModel, DETOXIFY_MODEL: DetoxifyHelperModel}
 35TOXICITY_BATCH_SIZE = 64
 36
 37logger = logging.getLogger(__name__)
 38
 39
 40class ToxicityScores(BatchedTransform):
 41    """This transform computes toxicity scores on a batch of records at a time using a helper model.
 42
 43    This transform augments the input batch with the computed scores.
 44    """
 45
 46    def __init__(
 47        self,
 48        input_key: str,
 49        toxicity_helper_model: Union[ToxigenHelperModel, DetoxifyHelperModel, ActorHandle],
 50    ):
 51        """ToxicityScores initializer.
 52
 53        :param input_key: The key corresponding to the batch data to be processed by this transform.
 54        :param toxicity_helper_model: A toxicity helper model instance (see MODEL_TYPES_SUPPORTED
 55            for the supported helper models) or a Ray actor handle for a helper model.
 56        """
 57        super().__init__(input_key, toxicity_helper_model)
 58        score_names = (
 59            toxicity_helper_model.get_score_names()
 60            if (
 61                isinstance(toxicity_helper_model, ToxigenHelperModel)
 62                or isinstance(toxicity_helper_model, DetoxifyHelperModel)
 63            )
 64            else ray.get(toxicity_helper_model.get_score_names.remote())  # type: ignore
 65        )
 66        self.register_input_output_keys(
 67            input_keys=[input_key],
 68            output_keys=score_names,
 69        )
 70        self.input_key = input_key
 71        self.toxicity_helper_model = toxicity_helper_model
 72
 73    @property
 74    def batch_size(self) -> int:
 75        """The batch size to use when invoking the toxicity helper model."""
 76        return TOXICITY_BATCH_SIZE  # pragma: no cover
 77
 78    @validate_call
 79    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
 80        """Augment the input batch with toxicity scores computed by the helper model.
 81
 82        :param batch: The input batch.
 83        :returns: The input batch with toxicity scores added in.
 84        """
 85        text_input: List[str] = batch[self.input_key].tolist()
 86        scores = (
 87            self.toxicity_helper_model.get_helper_scores(text_input)
 88            if isinstance(self.toxicity_helper_model, BaseHelperModel)
 89            else ray.get(self.toxicity_helper_model.get_helper_scores.remote(text_input))
 90        )
 91        for key, value in scores.items():
 92            batch.update({key: np.array(value)})
 93        return batch
 94
 95
 96@dataclass(frozen=True)
 97class ToxicityConfig(EvalAlgorithmConfig):
 98    """
 99    Configuration for the toxicity eval algorithm
100
101    :param model_type: Which toxicity detector to use. Choose between "toxigen" and "detoxify".
102    """
103
104    model_type: str = DEFAULT_MODEL_TYPE
105
106    def __post_init__(self):
107        if self.model_type not in MODEL_TYPES_SUPPORTED:
108            raise EvalAlgorithmClientError(
109                f"Invalid model_type: {self.model_type} requested in ToxicityConfig, "
110                f"please choose from acceptable values: {MODEL_TYPES_SUPPORTED}"
111            )
112
113
114TOXICITY = EvalAlgorithm.TOXICITY.value
115
116
117class Toxicity(EvalAlgorithmInterface):
118    """
119    This evaluation measures whether a model outputs toxic content, and it can be performed over any task that involves the generation of content (including open-ended generation, summarization and question answering). The toxicity score is given by one of two built-in toxicity detectors, "toxigen" and "detoxify". Configure which one to use inside the `ToxicityConfig`.
120
121    Disclaimer: the concept of toxicity is cultural and context dependent. As this evaluation employs a model to score generated passages, the various scores represent the “view” of the toxicity detector used.
122    """
123
124    eval_name = TOXICITY
125
126    def __init__(self, eval_algorithm_config: ToxicityConfig = ToxicityConfig()):
127        """Toxicity initializer.
128
129        :param eval_algorithm_config: Toxicity evaluation algorithm config.
130        """
131        super().__init__(eval_algorithm_config)
132        self._helper_model = TOXICITY_HELPER_MODEL_MAPPING[eval_algorithm_config.model_type]()
133
134    def evaluate_sample(self, model_output: str) -> List[EvalScore]:  # type: ignore[override]
135        """Evaluate toxicity on a single datapoint.
136
137        :param model_output: The output of the model under evaluation.
138        :returns: A list of EvalScore objects representing the computed toxicity scores.
139        """
140        scores = self._helper_model.get_helper_scores([model_output])
141        return [EvalScore(name=key, value=value[0]) for key, value in scores.items()]
142
143    def evaluate(
144        self,
145        model: Optional[ModelRunner] = None,
146        dataset_config: Optional[Union[DataConfig, List[DataConfig]]] = None,
147        prompt_template: Optional[str] = None,
148        num_records: int = 100,
149        save: bool = False,
150        save_strategy: Optional[SaveStrategy] = None,
151    ) -> List[EvalOutput]:
152        """Compute toxicity metrics on one or more datasets.
153
154        :param model: An instance of ModelRunner representing the model under evaluation.
155            If this argument is None, the `dataset_config` argument must not be None,
156            and must correspond to a dataset that already contains a column with model outputs.
157        :param dataset_config: Configures a single dataset or list of datasets used for the
158            evaluation. If not provided, this method will run evaluations using all of its
159            supported built-in datasets.
160        :param prompt_template: A template used to generate prompts that are fed to the model.
161            If not provided, defaults will be used. If provided, `model` must not be None.
162        :param num_records: The number of records to be sampled randomly from the input dataset(s)
163            used to perform the evaluation(s).
164        :param save: If set to true, prompt responses and scores will be saved to a file.
165        :param save_strategy: Specifies the strategy used to save the localized outputs of the evaluations. If not
166            specified, results are saved to the path configured by the EVAL_RESULTS_PATH environment variable.
167            If that environment variable is also not set, results are saved to the default path `/tmp/eval_results/`.
168
169        :return: A list of EvalOutput objects.
170        """
171        toxicity_model_shared_resource = create_shared_resource(self._helper_model)
172        pipeline = TransformPipeline(
173            [
174                ToxicityScores(
175                    input_key=DatasetColumns.MODEL_OUTPUT.value.name,
176                    toxicity_helper_model=toxicity_model_shared_resource,
177                )
178            ]
179        )
180        dataset_configs = get_dataset_configs(dataset_config, self.eval_name)
181        eval_outputs = []
182        for dataset_config in dataset_configs:
183            dataset = get_dataset(dataset_config, num_records)
184            eval_output = evaluate_dataset(
185                dataset=dataset,
186                pipeline=pipeline,
187                dataset_name=dataset_config.dataset_name,
188                eval_name=self.eval_name,
189                metric_names=self._helper_model.get_score_names(),
190                eval_results_path=get_eval_results_path(),
191                model=model,
192                prompt_template=prompt_template,
193                agg_method=MEAN,
194                save=save,
195                save_strategy=save_strategy,
196            )
197            eval_outputs.append(eval_output)
198        cleanup_shared_resource(toxicity_model_shared_resource)
199        return eval_outputs
TOXIGEN_MODEL = 'toxigen'
DETOXIFY_MODEL = 'detoxify'
DEFAULT_MODEL_TYPE = 'detoxify'
MODEL_TYPES_SUPPORTED = ['toxigen', 'detoxify']
TOXICITY_BATCH_SIZE = 64
logger = <Logger fmeval.eval_algorithms.toxicity (WARNING)>
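As a brief illustration (based on how Toxicity.__init__ uses these constants), the sketch below resolves a model type to its helper model class via TOXICITY_HELPER_MODEL_MAPPING. Constructing a helper model loads the underlying detector, which may download model weights on first use.

    from fmeval.eval_algorithms.toxicity import (
        DEFAULT_MODEL_TYPE,
        MODEL_TYPES_SUPPORTED,
        TOXICITY_HELPER_MODEL_MAPPING,
    )

    print(MODEL_TYPES_SUPPORTED)  # ['toxigen', 'detoxify']

    # Resolve the default model type ('detoxify') to its helper model class
    # and instantiate it, just as Toxicity.__init__ does.
    helper_cls = TOXICITY_HELPER_MODEL_MAPPING[DEFAULT_MODEL_TYPE]
    helper_model = helper_cls()
    print(helper_model.get_score_names())  # score names reported by this detector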
class ToxicityScores(fmeval.transforms.batched_transform.BatchedTransform):

This transform computes toxicity scores on a batch of records at a time using a helper model.

This transform augments the input batch with the computed scores.

ToxicityScores(input_key: str, toxicity_helper_model: Union[fmeval.eval_algorithms.helper_models.helper_model.ToxigenHelperModel, fmeval.eval_algorithms.helper_models.helper_model.DetoxifyHelperModel, ray.actor.ActorHandle])

ToxicityScores initializer.

Parameters
  • input_key: The key corresponding to the batch data to be processed by this transform.
  • toxicity_helper_model: A toxicity helper model instance (see MODEL_TYPES_SUPPORTED for the supported helper models) or a Ray actor handle for a helper model.
input_key
toxicity_helper_model
batch_size: int

The batch size to use when invoking the toxicity helper model.
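A minimal sketch of applying ToxicityScores directly to a batch. It assumes DetoxifyHelperModel can be constructed with no arguments (as the Toxicity algorithm does) and that the batch is a dict of numpy arrays keyed by column name; a Ray actor handle obtained from create_shared_resource can be passed in the same way.

    import numpy as np

    from fmeval.constants import DatasetColumns
    from fmeval.eval_algorithms.helper_models.helper_model import DetoxifyHelperModel
    from fmeval.eval_algorithms.toxicity import ToxicityScores

    model_output_key = DatasetColumns.MODEL_OUTPUT.value.name
    scorer = ToxicityScores(
        input_key=model_output_key,
        toxicity_helper_model=DetoxifyHelperModel(),
    )

    batch = {model_output_key: np.array(["first model output", "second model output"])}
    scored_batch = scorer(batch)
    # scored_batch now contains one additional numpy array per score name
    # reported by the helper model (see get_score_names()).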

@dataclass(frozen=True)
class ToxicityConfig(fmeval.eval_algorithms.eval_algorithm.EvalAlgorithmConfig):

Configuration for the toxicity eval algorithm

Parameters
  • model_type: Which toxicity detector to use. Choose between "toxigen" and "detoxify".
ToxicityConfig(model_type: str = 'detoxify')
model_type: str = 'detoxify'
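A short sketch of configuring the detector; any value outside MODEL_TYPES_SUPPORTED is rejected in __post_init__ with an EvalAlgorithmClientError.

    from fmeval.eval_algorithms.toxicity import ToxicityConfig
    from fmeval.exceptions import EvalAlgorithmClientError

    config = ToxicityConfig(model_type="toxigen")  # default is "detoxify"

    try:
        ToxicityConfig(model_type="not-a-detector")
    except EvalAlgorithmClientError as err:
        print(err)  # lists the acceptable values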
TOXICITY = 'toxicity'
class Toxicity(fmeval.eval_algorithms.eval_algorithm.EvalAlgorithmInterface):

This evaluation measures whether a model outputs toxic content, and it can be performed over any task that involves the generation of content (including open-ended generation, summarization and question answering). The toxicity score is given by one of two built-in toxicity detectors, "toxigen" and "detoxify". Configure which one to use inside the ToxicityConfig.

Disclaimer: the concept of toxicity is cultural and context dependent. As this evaluation employs a model to score generated passages, the various scores represent the “view” of the toxicity detector used.

Toxicity(eval_algorithm_config: ToxicityConfig = ToxicityConfig(model_type='detoxify'))

Toxicity initializer.

Parameters
  • eval_algorithm_config: Toxicity evaluation algorithm config.
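A brief sketch of constructing the algorithm with each supported detector; the config's model_type determines which helper model is instantiated, which may download detector weights on first use.

    from fmeval.eval_algorithms.toxicity import Toxicity, ToxicityConfig

    detoxify_eval = Toxicity()  # default config uses the "detoxify" detector
    toxigen_eval = Toxicity(ToxicityConfig(model_type="toxigen"))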
eval_name = 'toxicity'
def evaluate_sample(self, model_output: str) -> List[fmeval.eval_algorithms.EvalScore]:

Evaluate toxicity on a single datapoint.

Parameters
  • model_output: The output of the model under evaluation.
Returns

A list of EvalScore objects representing the computed toxicity scores.
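A minimal usage sketch (the output string below is hypothetical); constructing Toxicity loads the default detoxify detector, and each returned EvalScore carries a score name and a value.

    from fmeval.eval_algorithms.toxicity import Toxicity

    toxicity = Toxicity()
    scores = toxicity.evaluate_sample("An example model output to score.")
    for score in scores:
        print(f"{score.name}: {score.value}")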
def evaluate(self, model: Optional[fmeval.model_runners.model_runner.ModelRunner] = None, dataset_config: Union[fmeval.data_loaders.data_config.DataConfig, List[fmeval.data_loaders.data_config.DataConfig], NoneType] = None, prompt_template: Optional[str] = None, num_records: int = 100, save: bool = False, save_strategy: Optional[fmeval.eval_algorithms.save_strategy.SaveStrategy] = None) -> List[fmeval.eval_algorithms.EvalOutput]:

Compute toxicity metrics on one or more datasets.

Parameters
  • model: An instance of ModelRunner representing the model under evaluation. If this argument is None, the dataset_config argument must not be None, and must correspond to a dataset that already contains a column with model outputs.
  • dataset_config: Configures a single dataset or list of datasets used for the evaluation. If not provided, this method will run evaluations using all of its supported built-in datasets.
  • prompt_template: A template used to generate prompts that are fed to the model. If not provided, defaults will be used. If provided, model must not be None.
  • num_records: The number of records to be sampled randomly from the input dataset(s) used to perform the evaluation(s).
  • save: If set to true, prompt responses and scores will be saved to a file.
  • save_strategy: Specifies the strategy used to save the localized outputs of the evaluations. If not specified, results are saved to the path configured by the EVAL_RESULTS_PATH environment variable. If that environment variable is also not set, results are saved to the default path /tmp/eval_results/.
Returns

A list of EvalOutput objects.
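A sketch of evaluating a local JSON Lines dataset that already contains model outputs, so no ModelRunner is needed. The DataConfig field values below (dataset name, file path, JMESPath locations) are illustrative assumptions; adjust them to your dataset.

    from fmeval.data_loaders.data_config import DataConfig
    from fmeval.eval_algorithms.toxicity import Toxicity

    # Hypothetical dataset: each JSON line has "prompt" and "response" fields.
    data_config = DataConfig(
        dataset_name="my_model_outputs",
        dataset_uri="outputs.jsonl",
        dataset_mime_type="application/jsonlines",
        model_input_location="prompt",
        model_output_location="response",
    )

    toxicity = Toxicity()
    eval_outputs = toxicity.evaluate(
        model=None,                  # outputs are already present in the dataset
        dataset_config=data_config,
        num_records=100,
        save=True,
    )
    for eval_output in eval_outputs:
        print(eval_output.dataset_name, eval_output.dataset_scores)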