fmeval.eval_algorithms.classification_accuracy_semantic_robustness
import logging
import warnings

from typing import Callable, List, Optional, Union
from dataclasses import dataclass

from fmeval.constants import (
    DatasetColumns,
    MEAN,
)
from fmeval.data_loaders.util import get_dataset
from fmeval.data_loaders.data_config import DataConfig
from fmeval.eval_algorithms.common import evaluate_dataset
from fmeval.eval_algorithms.save_strategy import SaveStrategy
from fmeval.eval_algorithms.semantic_robustness_utils import (
    SemanticRobustnessConfig,
    get_perturbation_transform,
    get_model_outputs_from_perturbed_inputs,
)
from fmeval.eval_algorithms.util import (
    get_dataset_configs,
    validate_dataset,
    create_model_invocation_pipeline,
)
from fmeval.eval_algorithms.eval_algorithm import EvalAlgorithmInterface
from fmeval.eval_algorithms import (
    EvalAlgorithm,
    EvalOutput,
    EvalScore,
    get_default_prompt_template,
    DEFAULT_PROMPT_TEMPLATE,
)
from fmeval.model_runners.model_runner import ModelRunner
from fmeval.eval_algorithms.classification_accuracy import (
    convert_model_output_to_label,
    CLASSIFICATION_ACCURACY_SCORE,
    UNIQUENESS_FACTOR,
    ClassificationAccuracyScores,
    CLASSIFIED_MODEL_OUTPUT_COLUMN_NAME,
)
from fmeval.transforms.semantic_robustness_metrics import MeanDeltaScores
from fmeval.transforms.transform_pipeline import TransformPipeline
from fmeval.util import get_eval_results_path

PREFIX_FOR_DELTA_SCORES = "delta_"
DELTA_CLASSIFICATION_ACCURACY_SCORE = PREFIX_FOR_DELTA_SCORES + CLASSIFICATION_ACCURACY_SCORE

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class ClassificationAccuracySemanticRobustnessConfig(SemanticRobustnessConfig):
    """Configures the Classification Accuracy Semantic Robustness evaluation algorithm.

    See SemanticRobustnessConfig for the configurable parameters that this config class inherits.

    :param valid_labels: A list of valid labels.
    :param converter_fn: Function to process model output to labels. Defaults to simple integer conversion.
    """

    valid_labels: Optional[List[str]] = None
    converter_fn: Callable[[str, List[str]], str] = convert_model_output_to_label

    def __post_init__(self):
        super().__post_init__()
        if self.valid_labels:
            for i, label in enumerate(self.valid_labels):
                if not isinstance(label, str):
                    warnings.warn("Valid labels should be strings, casting.")
                    self.valid_labels[i] = str(label)


class ClassificationAccuracySemanticRobustness(EvalAlgorithmInterface):
    """Semantic Robustness evaluation algorithm for Classification Accuracy.

    This evaluation measures how much Classification Accuracy changes as a result of semantic-preserving
    perturbations on the input. For example, if we apply the whitespace perturbation (adding extra whitespaces
    at random) to the input text, how much does this impact the ability of the model to correctly classify
    this text?

    The output difference is measured by computing the Classification Accuracy metrics before and after
    perturbing the inputs. We report the absolute value of the difference in scores, on average over
    P (`num_perturbations`) perturbed inputs: $$ \frac{1}{P} \sum_{i=1}^{P} |s - \bar{s}_i|,$$
    where $s$ is the score produced by the original metric (i.e., accuracy, precision, recall and balanced
    accuracy), and $\bar{s}_i$ is the metric evaluated after the i-th perturbation has been applied.

    For details on the Classification Accuracy metrics, see the Classification Accuracy evaluation.
    For details on perturbations, see the GeneralSemanticRobustness evaluation.
    """

    eval_name = EvalAlgorithm.CLASSIFICATION_ACCURACY_SEMANTIC_ROBUSTNESS.value

    def __init__(
        self,
        eval_algorithm_config: ClassificationAccuracySemanticRobustnessConfig = ClassificationAccuracySemanticRobustnessConfig(),
    ):
        """ClassificationAccuracySemanticRobustness initializer.

        :param eval_algorithm_config: Classification Accuracy Semantic Robustness evaluation algorithm config.
        """
        super().__init__(eval_algorithm_config)
        self.config = eval_algorithm_config
        self.perturbation_transform = get_perturbation_transform(eval_algorithm_config)
        self.valid_labels = eval_algorithm_config.valid_labels
        self.converter_fn = eval_algorithm_config.converter_fn

    def _build_pipeline(
        self,
        model: ModelRunner,
        prompt_template: str,
        valid_labels: Optional[List[str]],
    ) -> TransformPipeline:
        """Build the TransformPipeline to be used by `evaluate` and `evaluate_sample`.

        While other evaluation algorithms (ex: Classification Accuracy) can configure
        their TransformPipeline at algorithm initialization, because the Classification Accuracy
        Semantic Robustness algorithm's evaluation logic depends on the ModelRunner
        and prompt template that are evaluation-specific (i.e. these parameters aren't
        configured at the algorithm level), the pipeline used by this algorithm is built
        when `evaluate` or `evaluate_sample` is called.

        :param model: The ModelRunner representing the model under evaluation.
        :param prompt_template: A template that is used to construct the prompt fed to the model.
        :param valid_labels: A list of valid labels for the classified model output.
        :returns: A TransformPipeline that can be used by either `evaluate_sample` or `evaluate`.
        """
        get_perturbed_inputs, gen_perturbed_prompts, get_perturbed_outputs = get_model_outputs_from_perturbed_inputs(
            self.perturbation_transform,
            prompt_template,
            model,
        )

        original_scores = ClassificationAccuracyScores(valid_labels=valid_labels, converter_fn=self.converter_fn)
        perturbed_scores = [
            ClassificationAccuracyScores(
                valid_labels=valid_labels,
                model_output_key=perturbed_output_key,
                classified_model_output_key=f"{CLASSIFIED_MODEL_OUTPUT_COLUMN_NAME}_perturbed_{i}",
                classification_accuracy_score_key=f"{CLASSIFICATION_ACCURACY_SCORE}_perturbed_{i}",
                converter_fn=self.converter_fn,
            )
            for i, perturbed_output_key in enumerate(get_perturbed_outputs.output_keys)
        ]

        perturbed_score_keys = [
            perturbed_score_transform.classification_accuracy_score_key
            for perturbed_score_transform in perturbed_scores
        ]
        mean_delta_scores = MeanDeltaScores(
            {CLASSIFICATION_ACCURACY_SCORE: (perturbed_score_keys, DELTA_CLASSIFICATION_ACCURACY_SCORE)}
        )

        transforms = [
            get_perturbed_inputs,
            gen_perturbed_prompts,
            get_perturbed_outputs,
            original_scores,
            TransformPipeline(perturbed_scores),
            mean_delta_scores,
        ]
        pipeline = TransformPipeline(transforms)
        return pipeline

    def evaluate_sample(
        self,
        model_input: str,
        target_output: str,
        model: ModelRunner,
        prompt_template: str = DEFAULT_PROMPT_TEMPLATE,
    ) -> List[EvalScore]:
        """Compute classification accuracy semantic robustness metrics for a single sample.

        A sample is defined as a model input and target output pair.

        :param model_input: Text input, which will be composed into a prompt that gets fed to the model.
        :param target_output: The expected response from the model.
        :param model: An instance of ModelRunner representing the model under evaluation.
        :param prompt_template: A template used to compose the prompt from `model_input`.
        :return: A list of EvalScores.
        """
        sample = {
            DatasetColumns.MODEL_INPUT.value.name: model_input,
            DatasetColumns.TARGET_OUTPUT.value.name: target_output,
        }
        invoke_model = create_model_invocation_pipeline(model, prompt_template)
        compute_metrics = self._build_pipeline(model, prompt_template, self.valid_labels)
        pipeline = TransformPipeline([invoke_model, compute_metrics])
        output_record = pipeline.execute_record(sample)

        original_score = EvalScore(
            name=CLASSIFICATION_ACCURACY_SCORE, value=output_record[CLASSIFICATION_ACCURACY_SCORE]
        )
        delta_score = EvalScore(
            name=DELTA_CLASSIFICATION_ACCURACY_SCORE, value=output_record[DELTA_CLASSIFICATION_ACCURACY_SCORE]
        )
        return [original_score, delta_score]

    def evaluate(
        self,
        model: ModelRunner,
        dataset_config: Optional[Union[DataConfig, List[DataConfig]]] = None,
        prompt_template: Optional[str] = None,
        num_records: int = 100,
        save: bool = False,
        save_strategy: Optional[SaveStrategy] = None,
    ) -> List[EvalOutput]:
        """Compute classification accuracy semantic robustness metrics on one or more datasets.

        :param model: An instance of ModelRunner representing the model under evaluation.
            This is a required argument, as even if the dataset contains model outputs,
            semantic robustness algorithms rely on invoking a model on perturbed inputs
            to see how the model outputs from the perturbed inputs differ from the original
            model outputs.
        :param dataset_config: Configures a single dataset or list of datasets used for the
            evaluation. If not provided, this method will run evaluations using all of its
            supported built-in datasets.
        :param prompt_template: A template used to generate prompts. If not provided, defaults will be used.
        :param num_records: The number of records to be sampled randomly from the input dataset to perform
            the evaluation.
        :param save: If set to true, prompt responses and scores will be saved to a file.
        :param save_strategy: Specifies the strategy to use to save the localized outputs of the evaluations.
            If not specified, results are saved to the path configured by the EVAL_RESULTS_PATH environment
            variable. If that environment variable is also not configured, they are saved to the default
            path `/tmp/eval_results/`.

        :returns: A List of EvalOutput objects.
        """
        dataset_configs = get_dataset_configs(dataset_config, self.eval_name)
        eval_outputs: List[EvalOutput] = []

        for dataset_config in dataset_configs:
            dataset_prompt_template = (
                get_default_prompt_template(dataset_config.dataset_name) if not prompt_template else prompt_template
            )
            dataset = get_dataset(dataset_config, num_records)
            validate_dataset(dataset, [DatasetColumns.TARGET_OUTPUT.value.name, DatasetColumns.MODEL_INPUT.value.name])

            valid_labels = (
                self.valid_labels
                if self.valid_labels
                else dataset.unique(column=DatasetColumns.TARGET_OUTPUT.value.name)
            )
            row_count = dataset.count()
            if len(valid_labels) / (row_count + 1) < UNIQUENESS_FACTOR:  # pragma: no cover
                logger.warning(
                    f"The number of classes: {len(valid_labels)} in the dataset is too large "
                    f"for the number of rows in the dataset: {row_count}",
                )

            eval_output = evaluate_dataset(
                dataset=dataset,
                pipeline=self._build_pipeline(model, dataset_prompt_template, valid_labels),
                dataset_name=dataset_config.dataset_name,
                eval_name=self.eval_name,
                metric_names=[CLASSIFICATION_ACCURACY_SCORE, DELTA_CLASSIFICATION_ACCURACY_SCORE],
                eval_results_path=get_eval_results_path(),
                model=model,
                prompt_template=dataset_prompt_template,
                agg_method=MEAN,
                save=save,
                save_strategy=save_strategy if save_strategy else None,
            )
            eval_outputs.append(eval_output)

        return eval_outputs
class ClassificationAccuracySemanticRobustnessConfig(SemanticRobustnessConfig)
Configures the Classification Accuracy Semantic Robustness evaluation algorithm.
See SemanticRobustnessConfig for the configurable parameters that this config class inherits.
Parameters
- valid_labels: A list of valid labels.
- converter_fn: Function to process model output to labels. Defaults to simple integer conversion.
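As an illustration, here is a minimal sketch of constructing this config with a custom converter; map_sentiment_to_label is a hypothetical helper written for this example, not part of fmeval.

from typing import List

from fmeval.eval_algorithms.classification_accuracy_semantic_robustness import (
    ClassificationAccuracySemanticRobustnessConfig,
)

# Hypothetical converter: map a free-text sentiment answer to the label strings "0"/"1".
def map_sentiment_to_label(model_output: str, valid_labels: List[str]) -> str:
    return "1" if "positive" in model_output.lower() else "0"

# __post_init__ casts any non-string labels to strings and emits a warning.
config = ClassificationAccuracySemanticRobustnessConfig(
    valid_labels=["0", "1"],
    converter_fn=map_sentiment_to_label,
)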
def convert_model_output_to_label(model_output: str, valid_labels: List[str]) -> str:
    """Convert model output to string class label. The model is expected to return a label directly (if it has a
    classification head), or a string containing a label (if it has a language modelling head). In the latter case
    we strip any additional text (e.g. "The answer is 2." --> "2"). If no valid label is contained in the
    `model_output`, an "unknown" label is returned. Users can define other `converter_fn` functions, e.g. to
    translate a text label to a string ("NEGATIVE" --> "0").

    :param model_output: Value returned by the model.
    :param valid_labels: Valid labels.
    :return: `model_output` transformed into a label.
    """
    # normalise to lowercase & strip
    valid_labels = [label.lower().strip() for label in valid_labels]

    response_words = model_output.split(" ")
    predicted_labels = [word.lower().strip() for word in response_words if word.lower().strip() in valid_labels]
    # if there is more than one label in the model output we pick the first
    string_label = predicted_labels[0] if predicted_labels else UNKNOWN_LABEL

    return string_label
Convert model output to string class label. The model is expected to return a label directly (if it has a classification head), or a string containing a label (if it has a language modelling head). In the latter case we strip any additional text (e.g. "The answer is 2." --> "2"). If no valid label is contained in the model_output, an "unknown" label is returned. Users can define other converter_fn functions, e.g. to translate a text label to a string ("NEGATIVE" --> "0").

Parameters
- model_output: Value returned by the model.
- valid_labels: Valid labels.

Returns
model_output transformed into a label.
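A minimal usage sketch of the default converter; the exact value of UNKNOWN_LABEL (shown here as "unknown") is assumed from the docstring rather than stated on this page.

from fmeval.eval_algorithms.classification_accuracy import convert_model_output_to_label

# Matching is word-level: the output is split on spaces and each lowercased,
# whitespace-stripped word is compared against the normalised valid labels.
print(convert_model_output_to_label("The answer is 2", valid_labels=["0", "1", "2"]))  # -> "2"
print(convert_model_output_to_label("No label here", valid_labels=["0", "1", "2"]))    # -> "unknown"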
class ClassificationAccuracySemanticRobustness(EvalAlgorithmInterface)
Semantic Robustness evaluation algorithm for Classification Accuracy.

This evaluation measures how much Classification Accuracy changes as a result of semantic-preserving perturbations on the input. For example, if we apply the whitespace perturbation (adding extra whitespaces at random) to the input text, how much does this impact the ability of the model to correctly classify this text?

The output difference is measured by computing the Classification Accuracy metrics before and after perturbing the inputs. We report the absolute value of the difference in scores, on average over P (num_perturbations) perturbed inputs: $$ \frac{1}{P} \sum_{i=1}^{P} |s - \bar{s}_i|, $$ where $s$ is the score produced by the original metric (i.e., accuracy, precision, recall and balanced accuracy), and $\bar{s}_i$ is the metric evaluated after the i-th perturbation has been applied.

For details on the Classification Accuracy metrics, see the Classification Accuracy evaluation. For details on perturbations, see the GeneralSemanticRobustness evaluation.
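As a small worked example of this formula (not taken from the library), suppose the original accuracy score is s = 1.0 and P = 3 perturbed inputs yield scores 1.0, 0.0 and 1.0; the delta score is then (0 + 1 + 0) / 3 ≈ 0.33.

# Mean absolute difference between the original score and the perturbed scores,
# i.e. (1/P) * sum_i |s - s_i|.
s = 1.0
perturbed_scores = [1.0, 0.0, 1.0]
delta = sum(abs(s - s_i) for s_i in perturbed_scores) / len(perturbed_scores)
print(delta)  # 0.333...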
def __init__(self, eval_algorithm_config: ClassificationAccuracySemanticRobustnessConfig = ClassificationAccuracySemanticRobustnessConfig())
ClassificationAccuracySemanticRobustness initializer.
Parameters
- eval_algorithm_config: Classification Accuracy Semantic Robustness evaluation algorithm config.
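A minimal initialization sketch, assuming the defaults inherited from SemanticRobustnessConfig (e.g. num_perturbations) are acceptable.

from fmeval.eval_algorithms.classification_accuracy_semantic_robustness import (
    ClassificationAccuracySemanticRobustness,
    ClassificationAccuracySemanticRobustnessConfig,
)

# valid_labels is optional; if omitted, evaluate() infers labels from the dataset's target outputs.
config = ClassificationAccuracySemanticRobustnessConfig(valid_labels=["0", "1"])
eval_algo = ClassificationAccuracySemanticRobustness(config)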
def evaluate_sample(self, model_input: str, target_output: str, model: ModelRunner, prompt_template: str = DEFAULT_PROMPT_TEMPLATE) -> List[EvalScore]
Compute classification accuracy semantic robustness metrics for a single sample.
A sample is defined as a model input and target output pair.
Parameters
- model_input: Text input, which will be composed into a prompt that gets fed to the model.
- target_output: The expected response from the model.
- model: An instance of ModelRunner representing the model under evaluation.
- prompt_template: A template used to compose the prompt from model_input.
Returns
A list of EvalScores.
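A usage sketch for a single sample; constructing a concrete ModelRunner (for example, a JumpStart or Bedrock model runner) is outside this snippet, so model_runner is assumed to exist, and the prompt template shown is illustrative.

from fmeval.eval_algorithms.classification_accuracy_semantic_robustness import (
    ClassificationAccuracySemanticRobustness,
    ClassificationAccuracySemanticRobustnessConfig,
)

eval_algo = ClassificationAccuracySemanticRobustness(
    ClassificationAccuracySemanticRobustnessConfig(valid_labels=["0", "1"])
)

# model_runner: a ModelRunner for the model under evaluation (assumed to exist).
scores = eval_algo.evaluate_sample(
    model_input="I really enjoyed this movie, it was fantastic.",
    target_output="1",
    model=model_runner,
    prompt_template="Classify the sentiment as 0 (negative) or 1 (positive): $model_input",
)
# One EvalScore for CLASSIFICATION_ACCURACY_SCORE and one for DELTA_CLASSIFICATION_ACCURACY_SCORE.
for score in scores:
    print(score.name, score.value)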
def evaluate(self, model: ModelRunner, dataset_config: Optional[Union[DataConfig, List[DataConfig]]] = None, prompt_template: Optional[str] = None, num_records: int = 100, save: bool = False, save_strategy: Optional[SaveStrategy] = None) -> List[EvalOutput]
Compute classification accuracy semantic robustness metrics on one or more datasets.
Parameters
- model: An instance of ModelRunner representing the model under evaluation. This is a required argument, as even if the dataset contains model outputs, semantic robustness algorithms rely on invoking a model on perturbed inputs to see how the model outputs from the perturbed inputs differ from the original model outputs.
- dataset_config: Configures a single dataset or list of datasets used for the evaluation. If not provided, this method will run evaluations using all of its supported built-in datasets.
- prompt_template: A template used to generate prompts. If not provided, defaults will be used.
- num_records: The number of records to be sampled randomly from the input dataset to perform the evaluation.
- save: If set to true, prompt responses and scores will be saved to a file.
- save_strategy: Specifies the strategy to use to save the localized outputs of the evaluations. If not specified, results are saved to the path configured by the EVAL_RESULTS_PATH environment variable; if that environment variable is also not configured, they are saved to the default path /tmp/eval_results/.

Returns
A list of EvalOutput objects.
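A dataset-level usage sketch; the DataConfig fields shown (dataset_uri, dataset_mime_type, model_input_location, target_output_location), the MIME_TYPE_JSONLINES constant, and the dataset itself are illustrative assumptions about a JSON Lines dataset, and eval_algo and model_runner are assumed to exist as in the evaluate_sample sketch above.

from fmeval.constants import MIME_TYPE_JSONLINES
from fmeval.data_loaders.data_config import DataConfig

# Hypothetical JSON Lines dataset in which each record has "text" and "label" fields.
data_config = DataConfig(
    dataset_name="custom_sentiment",
    dataset_uri="s3://my-bucket/sentiment.jsonl",
    dataset_mime_type=MIME_TYPE_JSONLINES,
    model_input_location="text",
    target_output_location="label",
)

eval_outputs = eval_algo.evaluate(
    model=model_runner,  # assumed ModelRunner for the model under evaluation
    dataset_config=data_config,
    prompt_template="Classify the sentiment as 0 (negative) or 1 (positive): $model_input",
    num_records=100,
    save=True,
)
print(eval_outputs)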