fmeval.eval_algorithms.general_semantic_robustness
import itertools
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

from fmeval.constants import (
    DatasetColumns,
    MEAN,
)
from fmeval.data_loaders.data_config import DataConfig
from fmeval.data_loaders.util import get_dataset
from fmeval.eval_algorithms import (
    EvalAlgorithm,
    EvalScore,
    EvalOutput,
    DEFAULT_PROMPT_TEMPLATE,
    get_default_prompt_template,
)
from fmeval.eval_algorithms.common import evaluate_dataset
from fmeval.eval_algorithms.eval_algorithm import EvalAlgorithmInterface
from fmeval.eval_algorithms.save_strategy import SaveStrategy
from fmeval.eval_algorithms.semantic_robustness_utils import (
    SemanticRobustnessConfig,
    get_perturbation_transform,
    get_model_outputs_from_perturbed_inputs,
)
from fmeval.eval_algorithms.helper_models.helper_model import BertscoreHelperModelTypes, BertscoreHelperModel
from fmeval.transforms.common import GeneratePrompt, GetModelOutputs
from fmeval.eval_algorithms.util import (
    validate_dataset,
    verify_model_determinism,
    get_dataset_configs,
)
from fmeval.model_runners.composers.composers import PromptComposer
from fmeval.model_runners.model_runner import ModelRunner
from fmeval.constants import BERTSCORE_DEFAULT_MODEL
from fmeval.transforms.summarization_accuracy_metrics import BertScore
from fmeval.transforms.semantic_robustness_metrics import BertScoreDissimilarity, WER
from fmeval.transforms.transform import Transform
from fmeval.transforms.transform_pipeline import TransformPipeline
from fmeval.transforms.util import create_output_key
from fmeval.util import create_shared_resource, require, get_eval_results_path, cleanup_shared_resource

logger = logging.getLogger(__name__)


WER_SCORE = "word_error_rate"
BERT_SCORE_DISSIMILARITY = "bertscore_dissimilarity"
BASELINE_SUFFIX = "baseline"
BASELINE_WER_SCORE = f"{WER_SCORE}_{BASELINE_SUFFIX}"
BASELINE_BERT_SCORE_DISSIMILARITY = f"{BERT_SCORE_DISSIMILARITY}_{BASELINE_SUFFIX}"


@dataclass(frozen=True)
class GeneralSemanticRobustnessConfig(SemanticRobustnessConfig):
    """Configures the general semantic robustness evaluation algorithm.

    :param num_baseline_samples: Only used for non-deterministic models. Number of times we generate
        the model output with the same input to compute the "baseline" change in model output. We
        compute differences between all pairs of outputs, i.e. between comb(num_baseline_samples, 2) pairs.

    :param model_type_for_bertscore: Model type to use for BERT score.
    """

    num_baseline_samples: int = 4
    model_type_for_bertscore: str = BERTSCORE_DEFAULT_MODEL

    def __post_init__(self):
        super().__post_init__()
        require(
            BertscoreHelperModelTypes.model_is_allowed(self.model_type_for_bertscore),
            f"Invalid model_type_for_bertscore: {self.model_type_for_bertscore} requested in "
            f"GeneralSemanticRobustnessConfig, please choose from acceptable values: {BertscoreHelperModelTypes.model_list()}.",
        )
        require(
            self.num_baseline_samples >= 2,
            f"Invalid num_baseline_samples: {self.num_baseline_samples} in GeneralSemanticRobustnessConfig. "
            f"The value should be at least 2.",
        )


class GeneralSemanticRobustness(EvalAlgorithmInterface):
    """Semantic Robustness evaluation algorithm for general task LLMs.

    This evaluation measures how much the model output changes as a result of semantic preserving
    perturbations. Given the input, e.g., "A quick brown fox jumps over the lazy dog", the
    evaluation creates a perturbation that preserves the semantic meaning of the input e.g.,
    whitespace perturbation that changes the input text to "A q uick bro wn fox ju mps overthe lazy
    dog". The evaluation then measures how much the model output changes when prompted with the
    original vs. perturbed input.

    The output difference is measured using two metrics: the [Word Error Rate](https://huggingface.co/spaces/evaluate-metric/wer)
    and the BERTScore Dissimilarity, which is
    1 - [BERTScore](https://huggingface.co/spaces/evaluate-metric/bertscore), between the original
    and the perturbed outputs. Word Error Rate measures syntactic differences, that is, changes in
    the words, whereas BERTScore Dissimilarity measures semantic differences. Semantic differences
    account of cases when the precise words in the output change but the meaning is the same, e.g.,
    consider the outputs "it is pouring down today" vs. "it is very rainy today".

    Note: When the model generation strategy is non-deterministic (e.g., with non-zero temperature),
    the output can change even if the input is the same. In such scenarios, reporting differences
    (using Word Error Rate or BERTScore Dissimilarity) between the model output on the original input
    and perturbed inputs might show artificially low robustness since the model output changes even
    without a change in the input. So this evaluation normalizes the robustness score to account for
    the baseline non-determinism. Specifically, if d is a score (Word Error Rate or BERTScore
    Dissimilarity), then the evaluation reports max(0, d - d_base) where d_base measures the
    differences between the model output on the same input.
    """

    eval_name = EvalAlgorithm.GENERAL_SEMANTIC_ROBUSTNESS.value

    def __init__(
        self,
        eval_algorithm_config: GeneralSemanticRobustnessConfig = GeneralSemanticRobustnessConfig(),
    ):
        """GeneralSemanticRobustness initializer.

        :param eval_algorithm_config: General semantic robustness evaluation algorithm config.
        """
        super().__init__(eval_algorithm_config)
        self.num_perturbations = eval_algorithm_config.num_perturbations
        self.num_baseline_samples = eval_algorithm_config.num_baseline_samples
        self.perturbation_transform = get_perturbation_transform(eval_algorithm_config)
        self.bertscore_model = BertscoreHelperModel(eval_algorithm_config.model_type_for_bertscore)

    def _build_pipeline(
        self,
        model: ModelRunner,
        prompt_template: str,
        is_deterministic: bool,
    ) -> TransformPipeline:
        """Build the TransformPipeline to be used by `evaluate` and `evaluate_sample`.

        While other evaluation algorithms (e.g. Summarization Accuracy) can configure
        their TransformPipeline at algorithm initialization, because the General
        Semantic Robustness algorithm's evaluation logic depends on the ModelRunner
        and prompt template that are evaluation-specific (i.e. these parameters aren't
        configured at the algorithm level), the pipeline used by the GSR algorithm is built
        when `evaluate` or `evaluate_sample` is called.

        :param model: The ModelRunner representing the model under evaluation.
        :param prompt_template: A template that is used to construct the prompt fed to the model.
        :param is_deterministic: Whether `model` produces deterministic results.
            In `evaluate_sample`, this is computed by invoking the model with the
            same input twice, and checking if the model output is the same.
            In `evaluate`, similar logic is used, but instead of using just a single input,
            multiple inputs from the dataset are used.
        :returns: A TransformPipeline that can be used by either `evaluate_sample` or `evaluate`.
        """
        (
            get_perturbed_inputs,
            gen_perturbed_prompts,
            get_perturbed_responses,
        ) = get_model_outputs_from_perturbed_inputs(
            self.perturbation_transform,
            prompt_template,
            model,
        )

        original_model_output_key = DatasetColumns.MODEL_OUTPUT.value.name
        # Compute BERTScores with target_output = the original model output
        # and model_output = the output from invoking the model with the perturbed prompt.
        get_bert_scores = BertScore(
            target_output_keys=[original_model_output_key],
            model_output_keys=get_perturbed_responses.output_keys,
            output_keys=[create_output_key(BertScore.__name__, i) for i in range(self.num_perturbations)],
            allow_duplicate_input_keys=True,
            bertscore_model=self.bertscore_model,
        )

        compute_bertscore_dissimilarity = BertScoreDissimilarity(
            bert_score_keys=get_bert_scores.output_keys,
            output_key=BERT_SCORE_DISSIMILARITY,
        )

        compute_wer_metric = WER(
            prediction_keys=get_perturbed_responses.output_keys,
            reference_keys=[original_model_output_key for _ in range(self.num_perturbations)],
            output_key=WER_SCORE,
        )

        transforms = [
            get_perturbed_inputs,
            gen_perturbed_prompts,
            get_perturbed_responses,
            get_bert_scores,
            compute_bertscore_dissimilarity,
            compute_wer_metric,
        ]

        pipeline = TransformPipeline(transforms)

        # If the model is not deterministic, we execute additional steps
        # to compute baseline scores for both BERTScore and WER.
        if not is_deterministic:
            # Invoke the model with the original (i.e. unperturbed) prompt
            # self.num_baseline_samples - 1 times.
            baseline_response_keys = [
                create_output_key(GeneratePrompt.__name__, BASELINE_SUFFIX, i)
                for i in range(self.num_baseline_samples - 1)
            ]
            get_baseline_outputs = GetModelOutputs(
                input_to_output_keys={DatasetColumns.PROMPT.value.name: baseline_response_keys},
                model_runner=model,
            )

            # Get every possible pair of model outputs.
            # The first output in the pair is treated as the target output
            # and the second output is treated as the model output
            # when computing the BERTScore.
            baseline_keys = baseline_response_keys + [DatasetColumns.MODEL_OUTPUT.value.name]
            all_pairs = itertools.combinations(baseline_keys, 2)
            first_output_keys, second_output_keys = zip(*all_pairs)

            # Compute baseline BERTScores and then compute BERTScore Dissimilarity using these BERTScores.
            get_baseline_bert_scores = BertScore(
                target_output_keys=list(first_output_keys),
                model_output_keys=list(second_output_keys),
                output_keys=[
                    create_output_key(BertScore.__name__, BASELINE_SUFFIX, i) for i in range(len(first_output_keys))
                ],
                allow_duplicate_input_keys=True,
                bertscore_model=self.bertscore_model,
            )
            compute_baseline_bertscore_dissimilarity = BertScoreDissimilarity(
                bert_score_keys=get_baseline_bert_scores.output_keys,
                output_key=BASELINE_BERT_SCORE_DISSIMILARITY,
            )

            # Compute WER metric using the baseline model outputs.
            compute_baseline_wer_metric = WER(
                prediction_keys=list(first_output_keys),
                reference_keys=list(second_output_keys),
                output_key=BASELINE_WER_SCORE,
            )
            # Update BERTScore Dissimilarity and WER metrics
            # given the new baseline scores that have been computed.
            update_scores = UpdateRobustnessScores()

            # Extend the pipeline with these additional steps.
            additional_steps = TransformPipeline(
                [
                    get_baseline_outputs,
                    get_baseline_bert_scores,
                    compute_baseline_bertscore_dissimilarity,
                    compute_baseline_wer_metric,
                    update_scores,
                ]
            )
            pipeline = TransformPipeline([pipeline, additional_steps])

        return pipeline

    def evaluate_sample(
        self,
        model_input: str,
        model: ModelRunner,
        prompt_template: str = DEFAULT_PROMPT_TEMPLATE,
    ) -> List[EvalScore]:  # type: ignore[override]
        """Compute general semantic robustness metrics for a single sample.

        :param model_input: Text input for model.
        :param model: An instance of ModelRunner representing the model under evaluation.
        :param prompt_template: A template that is used in conjunction with `model_input`
            to construct the prompt that is fed to the model.
        :returns: A list of EvalScore objects, one for each of the robustness metrics.
        """
        # Determine whether model produces deterministic outputs, as this affects
        # what steps will be included in the TransformPipeline.
        prompt_composer = PromptComposer(prompt_template)
        prompt = prompt_composer.compose(model_input)
        model_output = model.predict(prompt)[0]
        is_deterministic = model_output == model.predict(prompt)[0]

        sample = {
            DatasetColumns.MODEL_INPUT.value.name: model_input,
            DatasetColumns.PROMPT.value.name: prompt,
            DatasetColumns.MODEL_OUTPUT.value.name: model_output,
        }
        pipeline = self._build_pipeline(model, prompt_template, is_deterministic=is_deterministic)
        output_record = pipeline.execute_record(sample)

        bert_score_dissimilarity_value = output_record[BERT_SCORE_DISSIMILARITY]
        wer_value = output_record[WER_SCORE]
        return [
            EvalScore(name=BERT_SCORE_DISSIMILARITY, value=bert_score_dissimilarity_value),
            EvalScore(name=WER_SCORE, value=wer_value),
        ]

    def evaluate(
        self,
        model: ModelRunner,
        dataset_config: Optional[Union[DataConfig, List[DataConfig]]] = None,
        prompt_template: Optional[str] = None,
        num_records: int = 100,
        save: bool = False,
        save_strategy: Optional[SaveStrategy] = None,
    ) -> List[EvalOutput]:
        """Compute general semantic robustness metrics on one or more datasets.

        :param model: An instance of ModelRunner representing the model under evaluation.
            This is a required argument, as even if the dataset contains model outputs,
            semantic robustness algorithms rely on invoking a model on perturbed inputs
            to see how the model outputs from the perturbed inputs differ from the original
            model outputs.
        :param dataset_config: Configures a single dataset or list of datasets used for the
            evaluation. If not provided, this method will run evaluations using all of its
            supported built-in datasets.
        :param prompt_template: A template used to generate prompts that are fed to the model.
            If not provided, defaults will be used.
        :param num_records: The number of records to be sampled randomly from the input dataset
            used to perform the evaluation.
        :param save: If set to true, prompt responses and scores will be saved to a file.
        :param save_strategy: Specifies the strategy to use the save the localized outputs of the evaluations. If not
            specified, it will save it to the path that can be configured by the EVAL_RESULTS_PATH environment variable.
            If that environment variable is also not configured, it will be saved to the default path `/tmp/eval_results/`.

        :return: A list of EvalOutput objects.
        """
        # Create a shared resource to be used during the evaluation.
        bertscore_shared_resource = create_shared_resource(self.bertscore_model)
        dataset_configs = get_dataset_configs(dataset_config, self.eval_name)
        eval_outputs = []
        for dataset_config in dataset_configs:
            dataset = get_dataset(dataset_config, num_records)
            validate_dataset(dataset, [DatasetColumns.MODEL_INPUT.value.name])
            dataset_prompt_template = (
                get_default_prompt_template(dataset_config.dataset_name) if not prompt_template else prompt_template
            )
            is_deterministic = verify_model_determinism(model, dataset, dataset_prompt_template)
            eval_output = evaluate_dataset(
                dataset=dataset,
                pipeline=self._build_pipeline(model, dataset_prompt_template, is_deterministic=is_deterministic),
                dataset_name=dataset_config.dataset_name,
                eval_name=self.eval_name,
                metric_names=[BERT_SCORE_DISSIMILARITY, WER_SCORE],
                eval_results_path=get_eval_results_path(),
                model=model,
                prompt_template=dataset_prompt_template,
                agg_method=MEAN,
                save=save,
                save_strategy=save_strategy,
            )
            eval_outputs.append(eval_output)

        cleanup_shared_resource(bertscore_shared_resource)
        return eval_outputs


class UpdateRobustnessScores(Transform):
    """Used by General Semantic Robustness when the model under evaluation is not deterministic.

    See the class documentation for GeneralSemanticRobustness for details on how baseline scores
    are computed and used. This transform simply updates the data corresponding to the
    WER_SCORE and BERT_SCORE_DISSIMILARITY keys after baseline scores have been computed.
    """

    def __init__(self):
        super().__init__()
        self.register_input_output_keys(
            input_keys=[WER_SCORE, BERT_SCORE_DISSIMILARITY, BASELINE_WER_SCORE, BASELINE_BERT_SCORE_DISSIMILARITY],
            output_keys=[WER_SCORE, BERT_SCORE_DISSIMILARITY],
        )

    def __call__(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Update the values corresponding to the keys WER_SCORE and BERT_SCORE_DISSIMILARITY.

        This method does not add new keys, but rather mutates the data corresponding to existing
        keys (WER_SCORE and BERT_SCORE_DISSIMILARITY) in the input record.

        :param record: The input record.
        :returns: The input record with updated WER_SCORE and BERT_SCORE_DISSIMILARITY values.
        """
        bert_score_dissimilarity_value = record[BERT_SCORE_DISSIMILARITY]
        wer_value = record[WER_SCORE]
        baseline_bert_score_dissimilarity_value = record[BASELINE_BERT_SCORE_DISSIMILARITY]
        baseline_wer_value = record[BASELINE_WER_SCORE]

        record[BERT_SCORE_DISSIMILARITY] = max(
            0, bert_score_dissimilarity_value - baseline_bert_score_dissimilarity_value
        )
        record[WER_SCORE] = max(0, wer_value - baseline_wer_value)
        return record
@dataclass(frozen=True)
class GeneralSemanticRobustnessConfig(SemanticRobustnessConfig):
Configures the general semantic robustness evaluation algorithm.
Parameters
num_baseline_samples: Only used for non-deterministic models. Number of times we generate the model output with the same input to compute the "baseline" change in model output. We compute differences between all pairs of outputs, i.e. between comb(num_baseline_samples, 2) pairs.
model_type_for_bertscore: Model type to use for BERT score.
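For illustration, a minimal configuration sketch is shown below. It assumes that `num_perturbations` is a field inherited from `SemanticRobustnessConfig` (the class initializer above reads it from the config); the values are arbitrary, and the BERTScore model type is left at its package default.

from fmeval.eval_algorithms.general_semantic_robustness import GeneralSemanticRobustnessConfig

# num_perturbations is assumed to be inherited from SemanticRobustnessConfig;
# num_baseline_samples must be >= 2 and is only used for non-deterministic models.
config = GeneralSemanticRobustnessConfig(
    num_perturbations=5,
    num_baseline_samples=4,
)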
class GeneralSemanticRobustness(EvalAlgorithmInterface):
Semantic Robustness evaluation algorithm for general task LLMs.
This evaluation measures how much the model output changes as a result of semantic-preserving perturbations. Given an input, e.g., "A quick brown fox jumps over the lazy dog", the evaluation creates a perturbation that preserves its semantic meaning, e.g., a whitespace perturbation that changes the input text to "A q uick bro wn fox ju mps overthe lazy dog". The evaluation then measures how much the model output changes when the model is prompted with the original vs. the perturbed input.
The output difference is measured using two metrics: the Word Error Rate and BERTScore Dissimilarity, defined as 1 - BERTScore, both computed between the original and the perturbed outputs. Word Error Rate measures syntactic differences, that is, changes in the words, whereas BERTScore Dissimilarity measures semantic differences. Semantic differences account for cases where the precise words in the output change but the meaning stays the same, e.g., the outputs "it is pouring down today" vs. "it is very rainy today".
Note: When the model generation strategy is non-deterministic (e.g., with non-zero temperature), the output can change even when the input stays the same. In such scenarios, reporting differences (using Word Error Rate or BERTScore Dissimilarity) between the model outputs on the original and perturbed inputs can show artificially low robustness, since the output changes even without a change in the input. This evaluation therefore normalizes the robustness score to account for the baseline non-determinism. Specifically, if d is a score (Word Error Rate or BERTScore Dissimilarity), the evaluation reports max(0, d - d_base), where d_base measures the difference between model outputs generated from the same input.
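To make the normalization concrete, here is a small sketch of the reported score; the numbers are made up and only illustrate the max(0, d - d_base) rule described above.

def normalized_score(d: float, d_base: float) -> float:
    """Robustness score after subtracting the baseline non-determinism."""
    return max(0.0, d - d_base)

# Raw BERTScore Dissimilarity between original and perturbed outputs: 0.30.
# The model already differs by 0.12 across repeated runs on the *same* input.
print(normalized_score(0.30, 0.12))  # ~0.18
print(normalized_score(0.05, 0.12))  # 0.0 -- the reported score is never negative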
def __init__(self, eval_algorithm_config: GeneralSemanticRobustnessConfig = GeneralSemanticRobustnessConfig()):
GeneralSemanticRobustness initializer.
Parameters
- eval_algorithm_config: General semantic robustness evaluation algorithm config.
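Because `eval_algorithm_config` has a default value, the algorithm can be constructed with or without an explicit config; a brief sketch:

from fmeval.eval_algorithms.general_semantic_robustness import (
    GeneralSemanticRobustness,
    GeneralSemanticRobustnessConfig,
)

algo = GeneralSemanticRobustness()  # uses GeneralSemanticRobustnessConfig() defaults
custom_algo = GeneralSemanticRobustness(GeneralSemanticRobustnessConfig(num_baseline_samples=6))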
def evaluate_sample(self, model_input: str, model: ModelRunner, prompt_template: str = DEFAULT_PROMPT_TEMPLATE) -> List[EvalScore]:
Compute general semantic robustness metrics for a single sample.
Parameters
- model_input: Text input for model.
- model: An instance of ModelRunner representing the model under evaluation.
- prompt_template: A template that is used in conjunction with `model_input` to construct the prompt that is fed to the model.
Returns
A list of EvalScore objects, one for each of the robustness metrics.
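A hedged usage sketch follows. `my_model_runner` stands in for whatever `ModelRunner` implementation you have configured elsewhere (it is not defined in this module), and the default prompt template is used.

algo = GeneralSemanticRobustness()

scores = algo.evaluate_sample(
    model_input="A quick brown fox jumps over the lazy dog",
    model=my_model_runner,  # hypothetical ModelRunner instance
)
for score in scores:
    # One EvalScore for bertscore_dissimilarity and one for word_error_rate.
    print(score.name, score.value)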
def evaluate(self, model: ModelRunner, dataset_config: Optional[Union[DataConfig, List[DataConfig]]] = None, prompt_template: Optional[str] = None, num_records: int = 100, save: bool = False, save_strategy: Optional[SaveStrategy] = None) -> List[EvalOutput]:
Compute general semantic robustness metrics on one or more datasets.
Parameters
- model: An instance of ModelRunner representing the model under evaluation. This is a required argument, as even if the dataset contains model outputs, semantic robustness algorithms rely on invoking a model on perturbed inputs to see how the model outputs from the perturbed inputs differ from the original model outputs.
- dataset_config: Configures a single dataset or list of datasets used for the evaluation. If not provided, this method will run evaluations using all of its supported built-in datasets.
- prompt_template: A template used to generate prompts that are fed to the model. If not provided, defaults will be used.
- num_records: The number of records to be sampled randomly from the input dataset used to perform the evaluation.
- save: If set to true, prompt responses and scores will be saved to a file.
- save_strategy: Specifies the strategy used to save the localized outputs of the evaluations. If not specified, the outputs are saved to the path configured by the EVAL_RESULTS_PATH environment variable; if that environment variable is also not set, they are saved to the default path `/tmp/eval_results/`.
Returns
A list of EvalOutput objects.
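A sketch of a full-dataset run, again assuming `my_model_runner` is a `ModelRunner` configured elsewhere; with `dataset_config` left as None, the algorithm falls back to its supported built-in datasets as described above.

eval_outputs = algo.evaluate(
    model=my_model_runner,  # hypothetical ModelRunner instance
    dataset_config=None,    # run on all supported built-in datasets
    num_records=100,
    save=True,              # write per-record responses and scores to the eval results path
)
for eval_output in eval_outputs:
    print(eval_output)      # one EvalOutput per evaluated dataset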
class UpdateRobustnessScores(Transform):
Used by General Semantic Robustness when the model under evaluation is not deterministic.
See the class documentation for GeneralSemanticRobustness for details on how baseline scores are computed and used. This transform simply updates the data corresponding to the WER_SCORE and BERT_SCORE_DISSIMILARITY keys after baseline scores have been computed.
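The effect on a record can be sketched directly with the module-level key names defined in the source above (the values are made up, and the transform is invoked directly as a callable, as its __call__ signature suggests):

from fmeval.eval_algorithms.general_semantic_robustness import UpdateRobustnessScores

record = {
    "bertscore_dissimilarity": 0.30,
    "word_error_rate": 0.25,
    "bertscore_dissimilarity_baseline": 0.12,
    "word_error_rate_baseline": 0.10,
}
updated = UpdateRobustnessScores()(record)
print(updated["bertscore_dissimilarity"])  # ~0.18, i.e. max(0, 0.30 - 0.12)
print(updated["word_error_rate"])          # ~0.15, i.e. max(0, 0.25 - 0.10)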
def __init__(self):
Transform initializer.
Concrete subclasses of Transform should always call super().__init__ with every argument passed to their own __init__ method. Transform.__init__ stores all positional arguments in the `args` instance attribute and all keyword arguments in the `kwargs` instance attribute. This data is passed to Ray when Ray creates copies of this Transform instance to perform parallel execution.
Note: The `input_keys` and `output_keys` attributes are initialized to None and only assigned a meaningful value if the `register_input_output_keys` method is called. This method is used in conjunction with the `validate_call` decorator to perform validations of the `__call__` inputs and outputs at runtime. While it is not strictly necessary to utilize `register_input_output_keys` and `validate_call` when implementing your own transforms, these methods are used in all built-in transforms.
Parameters
- *args: Variable length argument list.
- **kwargs: Arbitrary keyword arguments.
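As a concrete illustration of this contract, below is a minimal custom transform sketch modeled on UpdateRobustnessScores above. The class and key names are hypothetical, and the `validate_call` decorator is omitted because its import path is not shown on this page.

from typing import Any, Dict

from fmeval.transforms.transform import Transform


class AddScoreDelta(Transform):
    """Hypothetical transform that writes the difference of two existing score keys."""

    def __init__(self, score_key: str, baseline_key: str, output_key: str):
        # Pass every constructor argument to Transform.__init__ so that Ray can
        # recreate copies of this transform for parallel execution.
        super().__init__(score_key, baseline_key, output_key)
        self.score_key = score_key
        self.baseline_key = baseline_key
        self.output_key = output_key
        # Register input/output keys so __call__ inputs and outputs can be validated.
        self.register_input_output_keys(
            input_keys=[score_key, baseline_key],
            output_keys=[output_key],
        )

    def __call__(self, record: Dict[str, Any]) -> Dict[str, Any]:
        record[self.output_key] = record[self.score_key] - record[self.baseline_key]
        return record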