fmeval.eval_algorithms.classification_accuracy_semantic_robustness

import logging
import warnings

from typing import Callable, List, Optional, Union
from dataclasses import dataclass

from fmeval.constants import (
    DatasetColumns,
    MEAN,
)
from fmeval.data_loaders.util import get_dataset
from fmeval.data_loaders.data_config import DataConfig
from fmeval.eval_algorithms.common import evaluate_dataset
from fmeval.eval_algorithms.save_strategy import SaveStrategy
from fmeval.eval_algorithms.semantic_robustness_utils import (
    SemanticRobustnessConfig,
    get_perturbation_transform,
    get_model_outputs_from_perturbed_inputs,
)
from fmeval.eval_algorithms.util import (
    get_dataset_configs,
    validate_dataset,
    create_model_invocation_pipeline,
)
from fmeval.eval_algorithms.eval_algorithm import EvalAlgorithmInterface
from fmeval.eval_algorithms import (
    EvalAlgorithm,
    EvalOutput,
    EvalScore,
    get_default_prompt_template,
    DEFAULT_PROMPT_TEMPLATE,
)
from fmeval.model_runners.model_runner import ModelRunner
from fmeval.eval_algorithms.classification_accuracy import (
    convert_model_output_to_label,
    CLASSIFICATION_ACCURACY_SCORE,
    UNIQUENESS_FACTOR,
    ClassificationAccuracyScores,
    CLASSIFIED_MODEL_OUTPUT_COLUMN_NAME,
)
from fmeval.transforms.semantic_robustness_metrics import MeanDeltaScores
from fmeval.transforms.transform_pipeline import TransformPipeline
from fmeval.util import get_eval_results_path

PREFIX_FOR_DELTA_SCORES = "delta_"
DELTA_CLASSIFICATION_ACCURACY_SCORE = PREFIX_FOR_DELTA_SCORES + CLASSIFICATION_ACCURACY_SCORE

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class ClassificationAccuracySemanticRobustnessConfig(SemanticRobustnessConfig):
    """Configures the Classification Accuracy Semantic Robustness evaluation algorithm.

    See SemanticRobustnessConfig for the configurable parameters that this config class inherits.

    :param valid_labels: A list of valid labels.
    :param converter_fn: Function to convert the model output to a class label.
        Defaults to `convert_model_output_to_label`.
    """

    valid_labels: Optional[List[str]] = None
    converter_fn: Callable[[str, List[str]], str] = convert_model_output_to_label

    def __post_init__(self):
        super().__post_init__()
        if self.valid_labels:
            for i, label in enumerate(self.valid_labels):
                if not isinstance(label, str):
                    warnings.warn("Valid labels should be strings, casting.")
                    self.valid_labels[i] = str(label)


class ClassificationAccuracySemanticRobustness(EvalAlgorithmInterface):
    """Semantic Robustness evaluation algorithm for Classification Accuracy.

    This evaluation measures how much Classification Accuracy changes as a result of semantic-preserving
    perturbations on the input. For example, if we apply the whitespace perturbation (adding extra whitespaces
    at random) to the input text, how much does this impact the ability of the model to correctly classify the text?

    The output difference is measured by computing the Classification Accuracy metrics before and after perturbing
    the inputs. We report the absolute value of the difference in scores, averaged over the P (`num_perturbations`)
    perturbed inputs: $$ \frac{1}{P} \sum_{i=1}^{P} |s - \bar{s}_i|,$$
    where $s$ is the score produced by the original metric (i.e., accuracy, precision, recall and balanced accuracy),
    and $\bar{s}_i$ is the metric evaluated after the i-th perturbation has been applied.

    For details on the Classification Accuracy metrics, see the Classification Accuracy evaluation. For details on perturbations, see the GeneralSemanticRobustness evaluation.
    """

    eval_name = EvalAlgorithm.CLASSIFICATION_ACCURACY_SEMANTIC_ROBUSTNESS.value

    def __init__(
        self,
        eval_algorithm_config: ClassificationAccuracySemanticRobustnessConfig = ClassificationAccuracySemanticRobustnessConfig(),
    ):
        """ClassificationAccuracySemanticRobustness initializer.

        :param eval_algorithm_config: Classification Accuracy Semantic Robustness evaluation algorithm config.
        """
        super().__init__(eval_algorithm_config)
        self.config = eval_algorithm_config
        self.perturbation_transform = get_perturbation_transform(eval_algorithm_config)
        self.valid_labels = eval_algorithm_config.valid_labels
        self.converter_fn = eval_algorithm_config.converter_fn

    def _build_pipeline(
        self,
        model: ModelRunner,
        prompt_template: str,
        valid_labels: Optional[List[str]],
    ) -> TransformPipeline:
        """Build the TransformPipeline to be used by `evaluate` and `evaluate_sample`.

        While other evaluation algorithms (ex: Classification Accuracy) can configure
        their TransformPipeline at algorithm initialization, because the Classification Accuracy
        Semantic Robustness algorithm's evaluation logic depends on the ModelRunner
        and prompt template that are evaluation-specific (i.e. these parameters aren't
        configured at the algorithm level), the pipeline used by this algorithm is built
        when `evaluate` or `evaluate_sample` is called.

        :param model: The ModelRunner representing the model under evaluation.
        :param prompt_template: A template that is used to construct the prompt fed to the model.
        :param valid_labels: A list of valid labels for the classified model output.
        :returns: A TransformPipeline that can be used by either `evaluate_sample` or `evaluate`.
        """
        get_perturbed_inputs, gen_perturbed_prompts, get_perturbed_outputs = get_model_outputs_from_perturbed_inputs(
            self.perturbation_transform,
            prompt_template,
            model,
        )

        original_scores = ClassificationAccuracyScores(valid_labels=valid_labels, converter_fn=self.converter_fn)
        perturbed_scores = [
            ClassificationAccuracyScores(
                valid_labels=valid_labels,
                model_output_key=perturbed_output_key,
                classified_model_output_key=f"{CLASSIFIED_MODEL_OUTPUT_COLUMN_NAME}_perturbed_{i}",
                classification_accuracy_score_key=f"{CLASSIFICATION_ACCURACY_SCORE}_perturbed_{i}",
                converter_fn=self.converter_fn,
            )
            for i, perturbed_output_key in enumerate(get_perturbed_outputs.output_keys)
        ]

        perturbed_score_keys = [
            perturbed_score_transform.classification_accuracy_score_key
            for perturbed_score_transform in perturbed_scores
        ]
        mean_delta_scores = MeanDeltaScores(
            {CLASSIFICATION_ACCURACY_SCORE: (perturbed_score_keys, DELTA_CLASSIFICATION_ACCURACY_SCORE)}
        )

        transforms = [
            get_perturbed_inputs,
            gen_perturbed_prompts,
            get_perturbed_outputs,
            original_scores,
            TransformPipeline(perturbed_scores),
            mean_delta_scores,
        ]
        pipeline = TransformPipeline(transforms)
        return pipeline

    def evaluate_sample(
        self,
        model_input: str,
        target_output: str,
        model: ModelRunner,
        prompt_template: str = DEFAULT_PROMPT_TEMPLATE,
    ) -> List[EvalScore]:
        """Compute classification accuracy semantic robustness metrics for a single sample.

        A sample is defined as a model input and target output pair.

        :param model_input: Text input, which will be composed into a prompt that gets fed to the model.
        :param target_output: The expected response from the model.
        :param model: An instance of ModelRunner representing the model under evaluation.
        :param prompt_template: A template used to compose the prompt from `model_input`.
        :return: A list of EvalScores.
        """
        sample = {
            DatasetColumns.MODEL_INPUT.value.name: model_input,
            DatasetColumns.TARGET_OUTPUT.value.name: target_output,
        }
        invoke_model = create_model_invocation_pipeline(model, prompt_template)
        compute_metrics = self._build_pipeline(model, prompt_template, self.valid_labels)
        pipeline = TransformPipeline([invoke_model, compute_metrics])
        output_record = pipeline.execute_record(sample)

        original_score = EvalScore(
            name=CLASSIFICATION_ACCURACY_SCORE, value=output_record[CLASSIFICATION_ACCURACY_SCORE]
        )
        delta_score = EvalScore(
            name=DELTA_CLASSIFICATION_ACCURACY_SCORE, value=output_record[DELTA_CLASSIFICATION_ACCURACY_SCORE]
        )
        return [original_score, delta_score]

    def evaluate(
        self,
        model: ModelRunner,
        dataset_config: Optional[Union[DataConfig, List[DataConfig]]] = None,
        prompt_template: Optional[str] = None,
        num_records: int = 100,
        save: bool = False,
        save_strategy: Optional[SaveStrategy] = None,
    ) -> List[EvalOutput]:
        """Compute classification accuracy semantic robustness metrics on one or more datasets.

        :param model: An instance of ModelRunner representing the model under evaluation.
            This is a required argument, as even if the dataset contains model outputs,
            semantic robustness algorithms rely on invoking a model on perturbed inputs
            to see how the model outputs from the perturbed inputs differ from the original
            model outputs.
        :param dataset_config: Configures a single dataset or list of datasets used for the
            evaluation. If not provided, this method will run evaluations using all of its
            supported built-in datasets.
        :param prompt_template: An optional template used to generate prompts. If not provided,
            the dataset's default prompt template will be used.
        :param num_records: The number of records to be sampled randomly from the input dataset
            to perform the evaluation.
        :param save: If set to true, prompt responses and scores will be saved to a file.
        :param save_strategy: Specifies the strategy to use to save the localized outputs of the evaluations. If not
            specified, it will save it to the path that can be configured by the EVAL_RESULTS_PATH environment variable.
            If that environment variable is also not configured, it will be saved to the default path `/tmp/eval_results/`.

        :returns: A List of EvalOutput objects.
        """
        dataset_configs = get_dataset_configs(dataset_config, self.eval_name)
        eval_outputs: List[EvalOutput] = []

        for dataset_config in dataset_configs:
            dataset_prompt_template = (
                get_default_prompt_template(dataset_config.dataset_name) if not prompt_template else prompt_template
            )
            dataset = get_dataset(dataset_config, num_records)
            validate_dataset(dataset, [DatasetColumns.TARGET_OUTPUT.value.name, DatasetColumns.MODEL_INPUT.value.name])

            valid_labels = (
                self.valid_labels
                if self.valid_labels
                else dataset.unique(column=DatasetColumns.TARGET_OUTPUT.value.name)
            )
            row_count = dataset.count()
            if len(valid_labels) / (row_count + 1) < UNIQUENESS_FACTOR:  # pragma: no cover
                logger.warning(
                    f"The number of classes: {len(valid_labels)} in the dataset is too large "
                    f"for the number of rows in the dataset: {row_count}",
                )

            eval_output = evaluate_dataset(
                dataset=dataset,
                pipeline=self._build_pipeline(model, dataset_prompt_template, valid_labels),
                dataset_name=dataset_config.dataset_name,
                eval_name=self.eval_name,
                metric_names=[CLASSIFICATION_ACCURACY_SCORE, DELTA_CLASSIFICATION_ACCURACY_SCORE],
                eval_results_path=get_eval_results_path(),
                model=model,
                prompt_template=dataset_prompt_template,
                agg_method=MEAN,
                save=save,
                save_strategy=save_strategy if save_strategy else None,
            )
            eval_outputs.append(eval_output)

        return eval_outputs
PREFIX_FOR_DELTA_SCORES = 'delta_'
DELTA_CLASSIFICATION_ACCURACY_SCORE = 'delta_classification_accuracy_score'
@dataclass(frozen=True)
class ClassificationAccuracySemanticRobustnessConfig(fmeval.eval_algorithms.semantic_robustness_utils.SemanticRobustnessConfig):

Configures the Classification Accuracy Semantic Robustness evaluation algorithm.

See SemanticRobustnessConfig for the configurable parameters that this config class inherits.

Parameters
  • valid_labels: A list of valid labels. If not provided, the unique values in the dataset's target output column are used at evaluation time (see the configuration sketch below).
  • converter_fn: Function to convert the model output to a class label. Defaults to convert_model_output_to_label, which matches the model output against the valid labels.
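As a quick illustration, the sketch below builds a config for a hypothetical binary classification task; the label values and perturbation settings are example choices, not recommendations or defaults.

from fmeval.eval_algorithms.classification_accuracy_semantic_robustness import (
    ClassificationAccuracySemanticRobustnessConfig,
)

# Hypothetical binary task: string labels "0"/"1", five butter-finger
# perturbations per input, 10% character-corruption probability.
config = ClassificationAccuracySemanticRobustnessConfig(
    valid_labels=["0", "1"],
    num_perturbations=5,
    butter_finger_perturbation_prob=0.1,
)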
ClassificationAccuracySemanticRobustnessConfig( perturbation_type: str = 'butter_finger', num_perturbations: int = 5, butter_finger_perturbation_prob: float = 0.1, random_uppercase_corrupt_proportion: float = 0.1, whitespace_add_prob: float = 0.05, whitespace_remove_prob: float = 0.1, valid_labels: Optional[List[str]] = None, converter_fn: Callable[[str, List[str]], str] = <function convert_model_output_to_label>)
valid_labels: Optional[List[str]] = None
def converter_fn(model_output: str, valid_labels: List[str]) -> str:
def convert_model_output_to_label(model_output: str, valid_labels: List[str]) -> str:
    """Convert model output to string class label. The model is expected to return a label directly (if it has a
    classification head), or a string containing a label (if it has a language modelling head). In the latter case we
    strip any additional text (e.g. "The answer is 2." --> "2"). If no valid label is contained in the
    `model_output` an "unknown" label is returned. Users can define other `converter_fn`s, e.g. to translate a text
    label to a string ("NEGATIVE" --> "0").

    :param model_output: Value returned by the model.
    :param valid_labels: Valid labels.
    :return: `model_output` transformed into a label
    """
    # normalise to lowercase & strip
    valid_labels = [label.lower().strip() for label in valid_labels]

    response_words = model_output.split(" ")
    predicted_labels = [word.lower().strip() for word in response_words if word.lower().strip() in valid_labels]
    # if there is more than one label in the model output we pick the first
    string_label = predicted_labels[0] if predicted_labels else UNKNOWN_LABEL

    return string_label

Convert model output to string class label. The model is expected to return a label directly (if it has a classification head), or a string containing a label (if it has a language modelling head). In the latter case we strip any additional text (e.g. "The answer is 2." --> "2"). If no valid label is contained in the model_output, an "unknown" label is returned. Users can define other converter_fns, e.g. to translate a text label to a string ("NEGATIVE" --> "0").

Parameters
  • model_output: Value returned by the model.
  • valid_labels: Valid labels.
Returns

model_output transformed into a label
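For reference, here is how the default converter behaves on whitespace-separated labels, followed by a hypothetical custom converter_fn; the sentiment-word mapping is an illustrative assumption, not part of the library.

from typing import List

from fmeval.eval_algorithms.classification_accuracy import convert_model_output_to_label

# Default converter: returns the first valid label found among the output words.
convert_model_output_to_label("I think the label is 1", valid_labels=["0", "1"])  # -> "1"
convert_model_output_to_label("Not sure.", valid_labels=["0", "1"])               # -> "unknown"

# Hypothetical custom converter mapping sentiment words to string labels.
# It matches the Callable[[str, List[str]], str] shape expected by converter_fn.
def sentiment_to_label(model_output: str, valid_labels: List[str]) -> str:
    text = model_output.lower()
    if "negative" in text:
        return "0"
    if "positive" in text:
        return "1"
    return "unknown"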

class ClassificationAccuracySemanticRobustness(fmeval.eval_algorithms.eval_algorithm.EvalAlgorithmInterface):

Semantic Robustness evaluation algorithm for Classification Accuracy.

This evaluation measures how much Classification Accuracy changes as a result of semantic-preserving perturbations on the input. For example, if we apply the whitespace perturbation (adding extra whitespaces at random) to the input text, how much does this impact the ability of the model to correctly classify the text?

The output difference is measured by computing the Classification Accuracy metrics before and after perturbing the inputs. We report the absolute value of the difference in scores, averaged over the P (num_perturbations) perturbed inputs: $$ \frac{1}{P} \sum_{i=1}^{P} |s - \bar{s}_i|, $$ where $s$ is the score produced by the original metric (i.e., accuracy, precision, recall and balanced accuracy), and $\bar{s}_i$ is the metric evaluated after the i-th perturbation has been applied.

For details on the Classification Accuracy metrics, see the Classification Accuracy evaluation. For details on perturbations, see the GeneralSemanticRobustness evaluation.
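To make the delta concrete: the per-record classification accuracy score $s$ is either 0 or 1 (1 when the converted model output matches the target). If num_perturbations is 5, the original input is classified correctly ($s = 1$), and the model still classifies 3 of the 5 perturbed inputs correctly, the per-record delta is $\frac{1}{5}(0 + 0 + 0 + 1 + 1) = 0.4$. A delta of 0 therefore means the perturbations never flipped the classification outcome; the reported delta_classification_accuracy_score is this quantity averaged over the dataset.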

ClassificationAccuracySemanticRobustness( eval_algorithm_config: ClassificationAccuracySemanticRobustnessConfig = ClassificationAccuracySemanticRobustnessConfig(perturbation_type='butter_finger', num_perturbations=5, butter_finger_perturbation_prob=0.1, random_uppercase_corrupt_proportion=0.1, whitespace_add_prob=0.05, whitespace_remove_prob=0.1, valid_labels=None, converter_fn=<function convert_model_output_to_label>))

ClassificationAccuracySemanticRobustness initializer.

Parameters
  • eval_algorithm_config: Classification Accuracy Semantic Robustness evaluation algorithm config.
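For instance (a sketch, reusing the hypothetical config object constructed in the config section above):

from fmeval.eval_algorithms.classification_accuracy_semantic_robustness import (
    ClassificationAccuracySemanticRobustness,
)

# Default configuration: butter_finger perturbations, labels inferred from the dataset.
eval_algo = ClassificationAccuracySemanticRobustness()

# Or pass the config sketched earlier.
eval_algo = ClassificationAccuracySemanticRobustness(eval_algorithm_config=config)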
eval_name = 'classification_accuracy_semantic_robustness'
config
perturbation_transform
valid_labels
converter_fn
def evaluate_sample( self, model_input: str, target_output: str, model: fmeval.model_runners.model_runner.ModelRunner, prompt_template: str = '$model_input') -> List[fmeval.eval_algorithms.EvalScore]:

Compute classification accuracy semantic robustness metrics for a single sample.

A sample is defined as a model input and target output pair.

Parameters
  • model_input: Text input, which will be composed into a prompt that gets fed to the model.
  • target_output: The expected response from the model.
  • model: An instance of ModelRunner representing the model under evaluation.
  • prompt_template: A template used to compose the prompt from model_input.
Returns

A list of EvalScores.
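A usage sketch follows. It assumes model_runner is a ModelRunner you have already constructed for your endpoint (not shown here) and that eval_algo is the instance created above; the review text, labels, and prompt template are illustrative.

scores = eval_algo.evaluate_sample(
    model_input="I loved this movie, it was great.",
    target_output="1",
    model=model_runner,  # an existing fmeval ModelRunner instance (assumed)
    prompt_template=(
        "Classify the sentiment of the following review as 0 (negative) "
        "or 1 (positive): $model_input"
    ),
)
for score in scores:
    print(score.name, score.value)
# Prints classification_accuracy_score and delta_classification_accuracy_score.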

def evaluate( self, model: fmeval.model_runners.model_runner.ModelRunner, dataset_config: Union[fmeval.data_loaders.data_config.DataConfig, List[fmeval.data_loaders.data_config.DataConfig], NoneType] = None, prompt_template: Optional[str] = None, num_records: int = 100, save: bool = False, save_strategy: Optional[fmeval.eval_algorithms.save_strategy.SaveStrategy] = None) -> List[fmeval.eval_algorithms.EvalOutput]:

Compute classification accuracy semantic robustness metrics on one or more datasets.

Parameters
  • model: An instance of ModelRunner representing the model under evaluation. This is a required argument, as even if the dataset contains model outputs, semantic robustness algorithms rely on invoking a model on perturbed inputs to see how the model outputs from the perturbed inputs differ from the original model outputs.
  • dataset_config: Configures a single dataset or list of datasets used for the evaluation. If not provided, this method will run evaluations using all of its supported built-in datasets.
  • prompt_template: An optional template used to generate prompts. If not provided, the dataset's default prompt template is used.
  • num_records: The number of records to be sampled randomly from the input dataset to perform the evaluation.
  • save: If set to true, prompt responses and scores will be saved to a file.
  • save_strategy: Specifies the strategy used to save the localized outputs of the evaluation. If not specified, results are saved to the path configured by the EVAL_RESULTS_PATH environment variable; if that variable is also not set, the default path /tmp/eval_results/ is used.

Returns

A list of EvalOutput objects.
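A dataset-level sketch under the same assumptions (model_runner and eval_algo as above). With no dataset_config, the algorithm's built-in datasets are used; pass a DataConfig, or a list of them, to evaluate your own dataset.

eval_outputs = eval_algo.evaluate(
    model=model_runner,   # assumed existing ModelRunner
    num_records=100,      # number of records sampled per dataset
    save=True,            # write per-record responses and scores to the eval results path
)
for eval_output in eval_outputs:
    # Each EvalOutput holds the aggregated classification_accuracy_score and
    # delta_classification_accuracy_score for one dataset.
    print(eval_output)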