fmeval.eval_algorithms.toxicity
import logging
from dataclasses import dataclass
from typing import Optional, List, Union, Dict

import ray
from ray.actor import ActorHandle
import numpy as np
from fmeval.constants import DatasetColumns, MEAN
from fmeval.data_loaders.data_config import DataConfig
from fmeval.data_loaders.util import get_dataset
from fmeval.eval_algorithms import (
    EvalOutput,
    EvalScore,
    EvalAlgorithm,
)
from fmeval.eval_algorithms.common import evaluate_dataset
from fmeval.eval_algorithms.eval_algorithm import EvalAlgorithmInterface, EvalAlgorithmConfig
from fmeval.eval_algorithms.helper_models.helper_model import ToxigenHelperModel, DetoxifyHelperModel, BaseHelperModel
from fmeval.eval_algorithms.save_strategy import SaveStrategy
from fmeval.eval_algorithms.util import (
    get_dataset_configs,
)
from fmeval.transforms.batched_transform import BatchedTransform
from fmeval.transforms.transform_pipeline import TransformPipeline
from fmeval.transforms.util import validate_call
from fmeval.util import get_eval_results_path, create_shared_resource, cleanup_shared_resource
from fmeval.exceptions import EvalAlgorithmClientError
from fmeval.model_runners.model_runner import ModelRunner

TOXIGEN_MODEL = "toxigen"
DETOXIFY_MODEL = "detoxify"
DEFAULT_MODEL_TYPE = DETOXIFY_MODEL
MODEL_TYPES_SUPPORTED = [TOXIGEN_MODEL, DETOXIFY_MODEL]
TOXICITY_HELPER_MODEL_MAPPING = {TOXIGEN_MODEL: ToxigenHelperModel, DETOXIFY_MODEL: DetoxifyHelperModel}
TOXICITY_BATCH_SIZE = 64

logger = logging.getLogger(__name__)


class ToxicityScores(BatchedTransform):
    """This transform computes toxicity scores on a batch of records at a time using a helper model.

    This transform augments the input batch with the computed scores.
    """

    def __init__(
        self,
        input_key: str,
        toxicity_helper_model: Union[ToxigenHelperModel, DetoxifyHelperModel, ActorHandle],
    ):
        """ToxicityScores initializer.

        :param input_key: The key corresponding to the batch data to be processed by this transform.
        :param toxicity_helper_model: A toxicity helper model instance (see MODEL_TYPES_SUPPORTED
            for the supported helper models) or a Ray actor handle for a helper model.
        """
        super().__init__(input_key, toxicity_helper_model)
        score_names = (
            toxicity_helper_model.get_score_names()
            if (
                isinstance(toxicity_helper_model, ToxigenHelperModel)
                or isinstance(toxicity_helper_model, DetoxifyHelperModel)
            )
            else ray.get(toxicity_helper_model.get_score_names.remote())  # type: ignore
        )
        self.register_input_output_keys(
            input_keys=[input_key],
            output_keys=score_names,
        )
        self.input_key = input_key
        self.toxicity_helper_model = toxicity_helper_model

    @property
    def batch_size(self) -> int:
        """The batch size to use when invoking the toxicity helper model."""
        return TOXICITY_BATCH_SIZE  # pragma: no cover

    @validate_call
    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Augment the input batch with toxicity scores computed by the helper model.

        :param batch: The input batch.
        :returns: The input batch with toxicity scores added in.
        """
        text_input: List[str] = batch[self.input_key].tolist()
        scores = (
            self.toxicity_helper_model.get_helper_scores(text_input)
            if isinstance(self.toxicity_helper_model, BaseHelperModel)
            else ray.get(self.toxicity_helper_model.get_helper_scores.remote(text_input))
        )
        for key, value in scores.items():
            batch.update({key: np.array(value)})
        return batch


@dataclass(frozen=True)
class ToxicityConfig(EvalAlgorithmConfig):
    """
    Configuration for the toxicity eval algorithm

    :param model_type: Which toxicity detector to use. Choose between "toxigen" and "detoxify".
    """

    model_type: str = DEFAULT_MODEL_TYPE

    def __post_init__(self):
        if self.model_type not in MODEL_TYPES_SUPPORTED:
            raise EvalAlgorithmClientError(
                f"Invalid model_type: {self.model_type} requested in ToxicityConfig, "
                f"please choose from acceptable values: {MODEL_TYPES_SUPPORTED}"
            )


TOXICITY = EvalAlgorithm.TOXICITY.value


class Toxicity(EvalAlgorithmInterface):
    """
    This evaluation measures whether a model outputs toxic content, and it can be performed over any task
    that involves the generation of content (including open-ended generation, summarization and question
    answering). The toxicity score is given by one of two built-in toxicity detectors, "toxigen" and
    "detoxify". Configure which one to use inside the `ToxicityConfig`.

    Disclaimer: the concept of toxicity is cultural and context dependent. As this evaluation employs a
    model to score generated passages, the various scores represent the “view” of the toxicity detector used.
    """

    eval_name = TOXICITY

    def __init__(self, eval_algorithm_config: ToxicityConfig = ToxicityConfig()):
        """Toxicity initializer.

        :param eval_algorithm_config: Toxicity evaluation algorithm config.
        """
        super().__init__(eval_algorithm_config)
        self._helper_model = TOXICITY_HELPER_MODEL_MAPPING[eval_algorithm_config.model_type]()

    def evaluate_sample(self, model_output: str) -> List[EvalScore]:  # type: ignore[override]
        """Evaluate toxicity on a single datapoint.

        :param model_output: The output of the model under evaluation.
        :returns: A list of EvalScore objects representing the computed toxicity scores.
        """
        scores = self._helper_model.get_helper_scores([model_output])
        return [EvalScore(name=key, value=value[0]) for key, value in scores.items()]

    def evaluate(
        self,
        model: Optional[ModelRunner] = None,
        dataset_config: Optional[Union[DataConfig, List[DataConfig]]] = None,
        prompt_template: Optional[str] = None,
        num_records: int = 100,
        save: bool = False,
        save_strategy: Optional[SaveStrategy] = None,
    ) -> List[EvalOutput]:
        """Compute toxicity metrics on one or more datasets.

        :param model: An instance of ModelRunner representing the model under evaluation.
            If this argument is None, the `dataset_config` argument must not be None,
            and must correspond to a dataset that already contains a column with model outputs.
        :param dataset_config: Configures a single dataset or list of datasets used for the
            evaluation. If not provided, this method will run evaluations using all of its
            supported built-in datasets.
        :param prompt_template: A template used to generate prompts that are fed to the model.
            If not provided, defaults will be used. If provided, `model` must not be None.
        :param num_records: The number of records to be sampled randomly from the input dataset(s)
            used to perform the evaluation(s).
        :param save: If set to true, prompt responses and scores will be saved to a file.
        :param save_strategy: Specifies the strategy to use to save the localized outputs of the evaluations.
            If not specified, outputs will be saved to the path that can be configured by the EVAL_RESULTS_PATH
            environment variable. If that environment variable is also not configured, they will be saved to
            the default path `/tmp/eval_results/`.

        :return: A list of EvalOutput objects.
        """
        toxicity_model_shared_resource = create_shared_resource(self._helper_model)
        pipeline = TransformPipeline(
            [
                ToxicityScores(
                    input_key=DatasetColumns.MODEL_OUTPUT.value.name,
                    toxicity_helper_model=toxicity_model_shared_resource,
                )
            ]
        )
        dataset_configs = get_dataset_configs(dataset_config, self.eval_name)
        eval_outputs = []
        for dataset_config in dataset_configs:
            dataset = get_dataset(dataset_config, num_records)
            eval_output = evaluate_dataset(
                dataset=dataset,
                pipeline=pipeline,
                dataset_name=dataset_config.dataset_name,
                eval_name=self.eval_name,
                metric_names=self._helper_model.get_score_names(),
                eval_results_path=get_eval_results_path(),
                model=model,
                prompt_template=prompt_template,
                agg_method=MEAN,
                save=save,
                save_strategy=save_strategy,
            )
            eval_outputs.append(eval_output)
        cleanup_shared_resource(toxicity_model_shared_resource)
        return eval_outputs
class ToxicityScores(BatchedTransform):
    """This transform computes toxicity scores on a batch of records at a time using a helper model.

    This transform augments the input batch with the computed scores.
    """

    def __init__(
        self,
        input_key: str,
        toxicity_helper_model: Union[ToxigenHelperModel, DetoxifyHelperModel, ActorHandle],
    ):
        """ToxicityScores initializer.

        :param input_key: The key corresponding to the batch data to be processed by this transform.
        :param toxicity_helper_model: A toxicity helper model instance (see MODEL_TYPES_SUPPORTED
            for the supported helper models) or a Ray actor handle for a helper model.
        """
        super().__init__(input_key, toxicity_helper_model)
        score_names = (
            toxicity_helper_model.get_score_names()
            if (
                isinstance(toxicity_helper_model, ToxigenHelperModel)
                or isinstance(toxicity_helper_model, DetoxifyHelperModel)
            )
            else ray.get(toxicity_helper_model.get_score_names.remote())  # type: ignore
        )
        self.register_input_output_keys(
            input_keys=[input_key],
            output_keys=score_names,
        )
        self.input_key = input_key
        self.toxicity_helper_model = toxicity_helper_model

    @property
    def batch_size(self) -> int:
        """The batch size to use when invoking the toxicity helper model."""
        return TOXICITY_BATCH_SIZE  # pragma: no cover

    @validate_call
    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Augment the input batch with toxicity scores computed by the helper model.

        :param batch: The input batch.
        :returns: The input batch with toxicity scores added in.
        """
        text_input: List[str] = batch[self.input_key].tolist()
        scores = (
            self.toxicity_helper_model.get_helper_scores(text_input)
            if isinstance(self.toxicity_helper_model, BaseHelperModel)
            else ray.get(self.toxicity_helper_model.get_helper_scores.remote(text_input))
        )
        for key, value in scores.items():
            batch.update({key: np.array(value)})
        return batch
This transform computes toxicity scores on a batch of records at a time using a helper model.
This transform augments the input batch with the computed scores.
    def __init__(
        self,
        input_key: str,
        toxicity_helper_model: Union[ToxigenHelperModel, DetoxifyHelperModel, ActorHandle],
    ):
        """ToxicityScores initializer.

        :param input_key: The key corresponding to the batch data to be processed by this transform.
        :param toxicity_helper_model: A toxicity helper model instance (see MODEL_TYPES_SUPPORTED
            for the supported helper models) or a Ray actor handle for a helper model.
        """
        super().__init__(input_key, toxicity_helper_model)
        score_names = (
            toxicity_helper_model.get_score_names()
            if (
                isinstance(toxicity_helper_model, ToxigenHelperModel)
                or isinstance(toxicity_helper_model, DetoxifyHelperModel)
            )
            else ray.get(toxicity_helper_model.get_score_names.remote())  # type: ignore
        )
        self.register_input_output_keys(
            input_keys=[input_key],
            output_keys=score_names,
        )
        self.input_key = input_key
        self.toxicity_helper_model = toxicity_helper_model
ToxicityScores initializer.
Parameters
- input_key: The key corresponding to the batch data to be processed by this transform.
- toxicity_helper_model: A toxicity helper model instance (see MODEL_TYPES_SUPPORTED for the supported helper models) or a Ray actor handle for a helper model.
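For illustration only, the sketch below applies the transform directly to an in-memory batch, outside of the TransformPipeline that evaluate() builds. It assumes a local DetoxifyHelperModel instance (which may download model weights on first use) and that the transform can be invoked standalone on a dict of NumPy arrays keyed by the configured input_key; the example strings are made up.

import numpy as np

from fmeval.constants import DatasetColumns
from fmeval.eval_algorithms.helper_models.helper_model import DetoxifyHelperModel
from fmeval.eval_algorithms.toxicity import ToxicityScores

# Local helper model instance; evaluate() instead passes a Ray actor handle
# created via create_shared_resource.
helper_model = DetoxifyHelperModel()

scorer = ToxicityScores(
    input_key=DatasetColumns.MODEL_OUTPUT.value.name,
    toxicity_helper_model=helper_model,
)

# A minimal batch: one column of model outputs, keyed by the configured input_key.
batch = {
    DatasetColumns.MODEL_OUTPUT.value.name: np.array(
        ["Thanks, that was genuinely helpful.", "Nobody asked for your worthless opinion."]
    )
}

# The returned batch is the input batch augmented with one NumPy column per score
# name reported by the helper model.
scored = scorer(batch)
for name in helper_model.get_score_names():
    print(name, scored[name])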
@dataclass(frozen=True)
class ToxicityConfig(EvalAlgorithmConfig):
    """
    Configuration for the toxicity eval algorithm

    :param model_type: Which toxicity detector to use. Choose between "toxigen" and "detoxify".
    """

    model_type: str = DEFAULT_MODEL_TYPE

    def __post_init__(self):
        if self.model_type not in MODEL_TYPES_SUPPORTED:
            raise EvalAlgorithmClientError(
                f"Invalid model_type: {self.model_type} requested in ToxicityConfig, "
                f"please choose from acceptable values: {MODEL_TYPES_SUPPORTED}"
            )
Configuration for the toxicity eval algorithm
Parameters
- model_type: Which toxicity detector to use. Choose between "toxigen" and "detoxify".
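As a brief sketch, the config is a frozen dataclass whose validation runs in __post_init__; the module-level constants can be used in place of raw strings.

from fmeval.eval_algorithms.toxicity import ToxicityConfig, TOXIGEN_MODEL
from fmeval.exceptions import EvalAlgorithmClientError

config = ToxicityConfig(model_type=TOXIGEN_MODEL)  # ToxicityConfig() defaults to "detoxify"

try:
    ToxicityConfig(model_type="perspective")  # not in MODEL_TYPES_SUPPORTED
except EvalAlgorithmClientError as err:
    print(err)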
class Toxicity(EvalAlgorithmInterface):
    """
    This evaluation measures whether a model outputs toxic content, and it can be performed over any task
    that involves the generation of content (including open-ended generation, summarization and question
    answering). The toxicity score is given by one of two built-in toxicity detectors, "toxigen" and
    "detoxify". Configure which one to use inside the `ToxicityConfig`.

    Disclaimer: the concept of toxicity is cultural and context dependent. As this evaluation employs a
    model to score generated passages, the various scores represent the “view” of the toxicity detector used.
    """

    eval_name = TOXICITY

    def __init__(self, eval_algorithm_config: ToxicityConfig = ToxicityConfig()):
        """Toxicity initializer.

        :param eval_algorithm_config: Toxicity evaluation algorithm config.
        """
        super().__init__(eval_algorithm_config)
        self._helper_model = TOXICITY_HELPER_MODEL_MAPPING[eval_algorithm_config.model_type]()

    def evaluate_sample(self, model_output: str) -> List[EvalScore]:  # type: ignore[override]
        """Evaluate toxicity on a single datapoint.

        :param model_output: The output of the model under evaluation.
        :returns: A list of EvalScore objects representing the computed toxicity scores.
        """
        scores = self._helper_model.get_helper_scores([model_output])
        return [EvalScore(name=key, value=value[0]) for key, value in scores.items()]

    def evaluate(
        self,
        model: Optional[ModelRunner] = None,
        dataset_config: Optional[Union[DataConfig, List[DataConfig]]] = None,
        prompt_template: Optional[str] = None,
        num_records: int = 100,
        save: bool = False,
        save_strategy: Optional[SaveStrategy] = None,
    ) -> List[EvalOutput]:
        """Compute toxicity metrics on one or more datasets.

        :param model: An instance of ModelRunner representing the model under evaluation.
            If this argument is None, the `dataset_config` argument must not be None,
            and must correspond to a dataset that already contains a column with model outputs.
        :param dataset_config: Configures a single dataset or list of datasets used for the
            evaluation. If not provided, this method will run evaluations using all of its
            supported built-in datasets.
        :param prompt_template: A template used to generate prompts that are fed to the model.
            If not provided, defaults will be used. If provided, `model` must not be None.
        :param num_records: The number of records to be sampled randomly from the input dataset(s)
            used to perform the evaluation(s).
        :param save: If set to true, prompt responses and scores will be saved to a file.
        :param save_strategy: Specifies the strategy to use to save the localized outputs of the evaluations.
            If not specified, outputs will be saved to the path that can be configured by the EVAL_RESULTS_PATH
            environment variable. If that environment variable is also not configured, they will be saved to
            the default path `/tmp/eval_results/`.

        :return: A list of EvalOutput objects.
        """
        toxicity_model_shared_resource = create_shared_resource(self._helper_model)
        pipeline = TransformPipeline(
            [
                ToxicityScores(
                    input_key=DatasetColumns.MODEL_OUTPUT.value.name,
                    toxicity_helper_model=toxicity_model_shared_resource,
                )
            ]
        )
        dataset_configs = get_dataset_configs(dataset_config, self.eval_name)
        eval_outputs = []
        for dataset_config in dataset_configs:
            dataset = get_dataset(dataset_config, num_records)
            eval_output = evaluate_dataset(
                dataset=dataset,
                pipeline=pipeline,
                dataset_name=dataset_config.dataset_name,
                eval_name=self.eval_name,
                metric_names=self._helper_model.get_score_names(),
                eval_results_path=get_eval_results_path(),
                model=model,
                prompt_template=prompt_template,
                agg_method=MEAN,
                save=save,
                save_strategy=save_strategy,
            )
            eval_outputs.append(eval_output)
        cleanup_shared_resource(toxicity_model_shared_resource)
        return eval_outputs
This evaluation measures whether a model outputs toxic content, and it can be performed over any task that involves the generation of content (including open-ended generation, summarization and question answering). The toxicity score is given by one of two built-in toxicity detectors, "toxigen" and "detoxify". Configure which one to use inside the ToxicityConfig.
Disclaimer: the concept of toxicity is cultural and context dependent. As this evaluation employs a model to score generated passages, the various scores represent the “view” of the toxicity detector used.
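A minimal construction sketch, assuming the package is installed locally: the helper model is instantiated eagerly in the initializer, so the chosen detector's weights may be downloaded at that point.

from fmeval.eval_algorithms.toxicity import Toxicity, ToxicityConfig

# Default configuration uses the "detoxify" detector.
toxicity = Toxicity()

# Explicitly select the "toxigen" detector instead; the two detectors report
# different sets of score names.
toxicity_toxigen = Toxicity(ToxicityConfig(model_type="toxigen"))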
    def __init__(self, eval_algorithm_config: ToxicityConfig = ToxicityConfig()):
        """Toxicity initializer.

        :param eval_algorithm_config: Toxicity evaluation algorithm config.
        """
        super().__init__(eval_algorithm_config)
        self._helper_model = TOXICITY_HELPER_MODEL_MAPPING[eval_algorithm_config.model_type]()
Toxicity initializer.
Parameters
- eval_algorithm_config: Toxicity evaluation algorithm config.
    def evaluate_sample(self, model_output: str) -> List[EvalScore]:  # type: ignore[override]
        """Evaluate toxicity on a single datapoint.

        :param model_output: The output of the model under evaluation.
        :returns: A list of EvalScore objects representing the computed toxicity scores.
        """
        scores = self._helper_model.get_helper_scores([model_output])
        return [EvalScore(name=key, value=value[0]) for key, value in scores.items()]
Evaluate toxicity on a single datapoint.
Parameters
- model_output: The output of the model under evaluation.
Returns
A list of EvalScore objects representing the computed toxicity scores.
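A hedged usage sketch: the exact score names depend on the configured detector, so the snippet simply collects whatever names come back; the sample sentence is made up.

from fmeval.eval_algorithms.toxicity import Toxicity, ToxicityConfig

toxicity = Toxicity(ToxicityConfig(model_type="detoxify"))
scores = toxicity.evaluate_sample("Pineapple on pizza should be outlawed.")

# Each EvalScore carries a name and a value; a dict makes lookups easy.
print({score.name: score.value for score in scores})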
    def evaluate(
        self,
        model: Optional[ModelRunner] = None,
        dataset_config: Optional[Union[DataConfig, List[DataConfig]]] = None,
        prompt_template: Optional[str] = None,
        num_records: int = 100,
        save: bool = False,
        save_strategy: Optional[SaveStrategy] = None,
    ) -> List[EvalOutput]:
        """Compute toxicity metrics on one or more datasets.

        :param model: An instance of ModelRunner representing the model under evaluation.
            If this argument is None, the `dataset_config` argument must not be None,
            and must correspond to a dataset that already contains a column with model outputs.
        :param dataset_config: Configures a single dataset or list of datasets used for the
            evaluation. If not provided, this method will run evaluations using all of its
            supported built-in datasets.
        :param prompt_template: A template used to generate prompts that are fed to the model.
            If not provided, defaults will be used. If provided, `model` must not be None.
        :param num_records: The number of records to be sampled randomly from the input dataset(s)
            used to perform the evaluation(s).
        :param save: If set to true, prompt responses and scores will be saved to a file.
        :param save_strategy: Specifies the strategy to use to save the localized outputs of the evaluations.
            If not specified, outputs will be saved to the path that can be configured by the EVAL_RESULTS_PATH
            environment variable. If that environment variable is also not configured, they will be saved to
            the default path `/tmp/eval_results/`.

        :return: A list of EvalOutput objects.
        """
        toxicity_model_shared_resource = create_shared_resource(self._helper_model)
        pipeline = TransformPipeline(
            [
                ToxicityScores(
                    input_key=DatasetColumns.MODEL_OUTPUT.value.name,
                    toxicity_helper_model=toxicity_model_shared_resource,
                )
            ]
        )
        dataset_configs = get_dataset_configs(dataset_config, self.eval_name)
        eval_outputs = []
        for dataset_config in dataset_configs:
            dataset = get_dataset(dataset_config, num_records)
            eval_output = evaluate_dataset(
                dataset=dataset,
                pipeline=pipeline,
                dataset_name=dataset_config.dataset_name,
                eval_name=self.eval_name,
                metric_names=self._helper_model.get_score_names(),
                eval_results_path=get_eval_results_path(),
                model=model,
                prompt_template=prompt_template,
                agg_method=MEAN,
                save=save,
                save_strategy=save_strategy,
            )
            eval_outputs.append(eval_output)
        cleanup_shared_resource(toxicity_model_shared_resource)
        return eval_outputs
Compute toxicity metrics on one or more datasets.
Parameters
- model: An instance of ModelRunner representing the model under evaluation. If this argument is None, the dataset_config argument must not be None, and must correspond to a dataset that already contains a column with model outputs.
- dataset_config: Configures a single dataset or list of datasets used for the evaluation. If not provided, this method will run evaluations using all of its supported built-in datasets.
- prompt_template: A template used to generate prompts that are fed to the model. If not provided, defaults will be used. If provided, model must not be None.
- num_records: The number of records to be sampled randomly from the input dataset(s) used to perform the evaluation(s).
- save: If set to true, prompt responses and scores will be saved to a file.
- save_strategy: Specifies the strategy to use to save the localized outputs of the evaluations. If not specified, outputs are saved to the path configured by the EVAL_RESULTS_PATH environment variable; if that variable is also not set, they are saved to the default path /tmp/eval_results/.
Returns
A list of EvalOutput objects.
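The sketch below runs the evaluation on a hypothetical JSON Lines dataset that already contains model outputs, so model is left as None. The file path and field names are illustrative, and the DataConfig keyword arguments reflect fmeval's DataConfig as the author understands it rather than anything defined in this module.

from fmeval.data_loaders.data_config import DataConfig
from fmeval.eval_algorithms.toxicity import Toxicity, ToxicityConfig

# Hypothetical dataset: each JSON line has a "question" field and an "answer"
# field holding a previously generated model output.
config = DataConfig(
    dataset_name="my_custom_dataset",
    dataset_uri="data/my_outputs.jsonl",
    dataset_mime_type="application/jsonlines",
    model_input_location="question",
    model_output_location="answer",
)

toxicity = Toxicity(ToxicityConfig(model_type="detoxify"))
eval_outputs = toxicity.evaluate(
    model=None,             # outputs already exist in the dataset
    dataset_config=config,
    num_records=100,
    save=True,              # persist per-record scores using the default save strategy
)

# Each EvalOutput is a dataclass holding the aggregated scores for one dataset.
for output in eval_outputs:
    print(output)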