fmeval.eval_algorithms.qa_accuracy_semantic_robustness

  1import logging
  2
  3from typing import List, Optional, Union
  4from dataclasses import dataclass
  5
  6from ray.actor import ActorHandle
  7from fmeval.constants import (
  8    DatasetColumns,
  9    MEAN,
 10    PREFIX_FOR_DELTA_SCORES,
 11    BERTSCORE_DEFAULT_MODEL,
 12)
 13from fmeval.data_loaders.util import get_dataset
 14from fmeval.data_loaders.data_config import DataConfig
 15from fmeval.eval_algorithms.common import evaluate_dataset
 16from fmeval.eval_algorithms.helper_models.helper_model import BertscoreHelperModelTypes, BertscoreHelperModel
 17from fmeval.eval_algorithms.save_strategy import SaveStrategy
 18from fmeval.eval_algorithms.semantic_robustness_utils import (
 19    SemanticRobustnessConfig,
 20    get_perturbation_transform,
 21    get_model_outputs_from_perturbed_inputs,
 22)
 23from fmeval.eval_algorithms.util import validate_dataset, create_model_invocation_pipeline, get_dataset_configs
 24from fmeval.eval_algorithms.eval_algorithm import EvalAlgorithmInterface
 25from fmeval.eval_algorithms import (
 26    EvalAlgorithm,
 27    EvalOutput,
 28    EvalScore,
 29    get_default_prompt_template,
 30    DEFAULT_PROMPT_TEMPLATE,
 31)
 32from fmeval.model_runners.model_runner import ModelRunner
 33from fmeval.eval_algorithms.qa_accuracy import (
 34    F1_SCORE,
 35    EXACT_MATCH_SCORE,
 36    QUASI_EXACT_MATCH_SCORE,
 37    PRECISION_OVER_WORDS,
 38    RECALL_OVER_WORDS,
 39    QAAccuracyScores,
 40    QA_ACCURACY_SCORE_NAMES,
 41    POSSIBLE_TARGETS,
 42)
 43from fmeval.transforms.common import SplitWithDelimiter
 44from fmeval.transforms.semantic_robustness_metrics import MeanDeltaScores
 45from fmeval.transforms.summarization_accuracy_metrics import BERT_SCORE, BertScore
 46from fmeval.transforms.transform_pipeline import TransformPipeline
 47from fmeval.transforms.util import create_output_key
 48from fmeval.util import require, get_eval_results_path, create_shared_resource, cleanup_shared_resource
 49
 50DELTA_F1_SCORE = PREFIX_FOR_DELTA_SCORES + F1_SCORE
 51DELTA_EXACT_MATCH_SCORE = PREFIX_FOR_DELTA_SCORES + EXACT_MATCH_SCORE
 52DELTA_QUASI_EXACT_MATCH_SCORE = PREFIX_FOR_DELTA_SCORES + QUASI_EXACT_MATCH_SCORE
 53DELTA_PRECISION_OVER_WORDS = PREFIX_FOR_DELTA_SCORES + PRECISION_OVER_WORDS
 54DELTA_RECALL_OVER_WORDS = PREFIX_FOR_DELTA_SCORES + RECALL_OVER_WORDS
 55DELTA_BERT_SCORE = PREFIX_FOR_DELTA_SCORES + BERT_SCORE
 56DELTA_SCORES = [
 57    DELTA_F1_SCORE,
 58    DELTA_EXACT_MATCH_SCORE,
 59    DELTA_QUASI_EXACT_MATCH_SCORE,
 60    DELTA_PRECISION_OVER_WORDS,
 61    DELTA_RECALL_OVER_WORDS,
 62    DELTA_BERT_SCORE,
 63]
 64ORIGINAL_SCORES = [
 65    F1_SCORE,
 66    EXACT_MATCH_SCORE,
 67    QUASI_EXACT_MATCH_SCORE,
 68    PRECISION_OVER_WORDS,
 69    RECALL_OVER_WORDS,
 70    BERT_SCORE,
 71]
 72
 73
 74logger = logging.getLogger(__name__)
 75
 76
 77@dataclass(frozen=True)
 78class QAAccuracySemanticRobustnessConfig(SemanticRobustnessConfig):
 79    """Configures the QA Accuracy Semantic Robustness evaluation algorithm.
 80
 81    See SemanticRobustnessConfig for the configurable parameters that this config class inherits.
 82
 83    :param target_output_delimiter: The target output can have multiple answers. We expect customers to combine all the
 84        possible answers into a single string and use the delimiter to separate them. For instance,
 85        if the answers are ["UK", "England"] and the delimiter="<OR>", then the target_output should be "UK<OR>England".
 86    :param model_type_for_bertscore: BERT model type to use for computing BERT score.
 87    """
 88
 89    target_output_delimiter: Optional[str] = "<OR>"
 90    model_type_for_bertscore: str = BERTSCORE_DEFAULT_MODEL
 91
 92    def __post_init__(self):
 93        super().__post_init__()
 94        require(
 95            self.target_output_delimiter != "",
 96            "Empty target_output_delimiter is provided. "
 97            "Please either provide a non-empty string, or set it to None.",
 98        )
 99        require(
100            BertscoreHelperModelTypes.model_is_allowed(self.model_type_for_bertscore),
101            f"Invalid model_type_for_bertscore: {self.model_type_for_bertscore} requested in "
102            f"QAAccuracySemanticRobustnessConfig. Please choose from acceptable values: "
103            f"{BertscoreHelperModelTypes.model_list()}.",
104        )
105
106
107class QAAccuracySemanticRobustness(EvalAlgorithmInterface):
108    """Semantic Robustness evaluation algorithm for QA Accuracy
109
110    This evaluation measures how much QA Accuracy changes as a result of semantic-preserving
111    perturbations on the input. For example, if we apply the whitespace perturbation (adding extra whitespaces at random) to the input text,
112    how much does the quality of the model answer change?
113
114    The output difference is measured by computing the QA Accuracy metrics before and after perturbing the inputs. We report the absolute value of the difference in scores
115    on average over $P$ (`num_perturbations`) perturbed inputs: $$ \frac{1}{P} \sum_{i=1}^{P} |s - \bar{s}_i|,$$
116    where $s$ is the score produced by the original metric (i.e., exact match, quasi-exact match, precision over words, recall over words, and F1 over words), and $\bar{s}_i$ is the metric evaluated after the i-th perturbation has been applied.
117
118    For details on the QA Accuracy metrics, see the QA Accuracy evaluation. For details on perturbations, see the GeneralSemanticRobustness evaluation.
119    """
120
121    eval_name = EvalAlgorithm.QA_ACCURACY_SEMANTIC_ROBUSTNESS.value
122
123    def __init__(
124        self, eval_algorithm_config: QAAccuracySemanticRobustnessConfig = QAAccuracySemanticRobustnessConfig()
125    ):
126        """QAAccuracySemanticRobustness initializer.
127
128        :param eval_algorithm_config: QA Accuracy Semantic Robustness evaluation algorithm config.
129        """
130        super().__init__(eval_algorithm_config)
131        self.config = eval_algorithm_config
132        self.perturbation_transform = get_perturbation_transform(eval_algorithm_config)
133        self.target_output_delimiter = eval_algorithm_config.target_output_delimiter
134        bertscore_model = BertscoreHelperModel(eval_algorithm_config.model_type_for_bertscore)
135        self.bertscore_model = bertscore_model
136
137    def _build_pipeline(
138        self, model: ModelRunner, prompt_template: str, bertscore_model: Union[BertscoreHelperModel, ActorHandle]
139    ) -> TransformPipeline:
140        """Build the TransformPipeline to be used by `evaluate` and `evaluate_sample`.
141
142        While other evaluation algorithms (e.g., QA Accuracy) can configure their
143        TransformPipeline at algorithm initialization, the QA Accuracy Semantic
144        Robustness algorithm's evaluation logic depends on the ModelRunner and prompt
145        template, which are evaluation-specific (i.e., not configured at the algorithm
146        level). The pipeline used by this algorithm is therefore built when `evaluate`
147        or `evaluate_sample` is called.
148
149        :param model: The ModelRunner representing the model under evaluation.
150        :param prompt_template: A template that is used to construct the prompt fed to the model.
151        :param bertscore_model: Either a BertscoreHelperModel instance or a Ray actor handle corresponding
152            to a BertscoreHelperModel (i.e. a shared resource).
153        :returns: A TransformPipeline that can be used by either `evaluate_sample` or `evaluate`.
154        """
155        transforms = get_model_outputs_from_perturbed_inputs(
156            self.perturbation_transform,
157            prompt_template,
158            model,
159        )
160        get_perturbed_inputs, gen_perturbed_prompts, get_perturbed_outputs = transforms
161
162        original_scores = QAAccuracyScores(target_output_delimiter=self.target_output_delimiter)
163        perturbed_scores = [
164            QAAccuracyScores(
165                model_output_key=perturbed_output_key,
166                target_output_delimiter=self.target_output_delimiter,
167                output_keys=[
168                    create_output_key(score_name, perturbed_output_key) for score_name in QA_ACCURACY_SCORE_NAMES
169                ],
170            )
171            for perturbed_output_key in get_perturbed_outputs.output_keys
172        ]
173
174        split_transform = SplitWithDelimiter(
175            input_key=DatasetColumns.TARGET_OUTPUT.value.name,
176            output_key=POSSIBLE_TARGETS,
177            target_output_delimiter=self.target_output_delimiter,
178        )
179        bert_score = BertScore(
180            target_output_keys=None,
181            model_output_keys=[DatasetColumns.MODEL_OUTPUT.value.name],
182            output_keys=[BERT_SCORE],
183            allow_duplicate_input_keys=True,
184            target_output_keys_provider=POSSIBLE_TARGETS,
185            bertscore_model=bertscore_model,
186        )
187
188        perturbed_bert_score = BertScore(
189            target_output_keys=None,
190            model_output_keys=get_perturbed_outputs.output_keys,
191            output_keys=[
192                create_output_key(BertScore.__name__, "perturbed", i) for i in range(self.config.num_perturbations)
193            ],
194            allow_duplicate_input_keys=True,
195            target_output_keys_provider=POSSIBLE_TARGETS,
196            bertscore_model=bertscore_model,
197        )
198
199        key_mapping = {
200            original_score_name: (
201                [perturbed_score_transform.output_keys[i] for perturbed_score_transform in perturbed_scores],
202                DELTA_SCORES[i],
203            )
204            for i, original_score_name in enumerate(QA_ACCURACY_SCORE_NAMES)
205        }
206
207        mean_delta_scores = MeanDeltaScores(key_mapping)
208
209        # key mapping + mean of bert scores
210        mean_delta_bert_scores = MeanDeltaScores(
211            {
212                bert_score.output_keys[0]: (perturbed_bert_score.output_keys, DELTA_BERT_SCORE),
213            }
214        )
215
216        transforms = [
217            get_perturbed_inputs,
218            gen_perturbed_prompts,
219            get_perturbed_outputs,
220            original_scores,
221            split_transform,
222            bert_score,
223            TransformPipeline(perturbed_scores),
224            perturbed_bert_score,
225            mean_delta_scores,
226            mean_delta_bert_scores,
227        ]
228        pipeline = TransformPipeline(transforms)
229        return pipeline
230
231    def evaluate_sample(
232        self,
233        model_input: str,
234        target_output: str,
235        model: ModelRunner,
236        prompt_template: str = DEFAULT_PROMPT_TEMPLATE,
237    ) -> List[EvalScore]:
238        """Compute question answering accuracy semantic robustness metrics for a single sample.
239
240        A sample is defined as a model input and target output pair.
241
242        :param model_input: Text input, which will be composed into a prompt that gets fed to the model.
243        :param target_output: The expected response from the model.
244        :param model: An instance of ModelRunner representing the model under evaluation.
245        :param prompt_template: A template used to compose the prompt from `model_input`.
246        :return: A list of EvalScores.
247        """
248        sample = {
249            DatasetColumns.MODEL_INPUT.value.name: model_input,
250            DatasetColumns.TARGET_OUTPUT.value.name: target_output,
251        }
252        invoke_model = create_model_invocation_pipeline(model, prompt_template)
253        compute_metrics = self._build_pipeline(model, prompt_template, self.bertscore_model)
254        pipeline = TransformPipeline([invoke_model, compute_metrics])
255        output_record = pipeline.execute_record(sample)
256
257        original_scores = [
258            EvalScore(name=score_name, value=output_record[score_name]) for score_name in ORIGINAL_SCORES
259        ]
260        delta_scores = [
261            EvalScore(name=delta_score_name, value=output_record[delta_score_name]) for delta_score_name in DELTA_SCORES
262        ]
263        return original_scores + delta_scores
264
265    def evaluate(
266        self,
267        model: ModelRunner,
268        dataset_config: Optional[Union[DataConfig, List[DataConfig]]] = None,
269        prompt_template: Optional[str] = None,
270        num_records: int = 100,
271        save: bool = False,
272        save_strategy: Optional[SaveStrategy] = None,
273    ) -> List[EvalOutput]:
274        """Compute QA accuracy semantic robustness metrics on one or more datasets.
275
276        :param model: An instance of ModelRunner representing the model under evaluation.
277            This is a required argument, as even if the dataset contains model outputs,
278            semantic robustness algorithms rely on invoking a model on perturbed inputs
279            to see how the model outputs from the perturbed inputs differ from the original
280            model outputs.
281        :param dataset_config: Configures a single dataset or list of datasets used for the
282            evaluation. If not provided, this method will run evaluations using all of its
283            supported built-in datasets.
284        :param prompt_template: An optional template used to generate prompts. If not provided, default templates
285            will be used.
286        :param num_records: The number of records to be sampled randomly from the input dataset to perform the
287                            evaluation.
288        :param save: If set to true, prompt responses and scores will be saved to a file.
289        :param save_strategy: Specifies the strategy to use to save the localized outputs of the evaluations. If not
290            specified, results will be saved to the path configured by the EVAL_RESULTS_PATH environment variable.
291            If that environment variable is also not configured, they will be saved to the default path `/tmp/eval_results/`.
292        :returns: A List of EvalOutput objects.
293        """
294        # Create a shared resource to be used during the evaluation.
295        bertscore_shared_resource = create_shared_resource(self.bertscore_model)
296
297        dataset_configs = get_dataset_configs(dataset_config, self.eval_name)
298        eval_outputs = []
299        for dataset_config in dataset_configs:
300            dataset_prompt_template = (
301                get_default_prompt_template(dataset_config.dataset_name) if not prompt_template else prompt_template
302            )
303            dataset = get_dataset(dataset_config, num_records)
304            validate_dataset(dataset, [DatasetColumns.MODEL_INPUT.value.name, DatasetColumns.TARGET_OUTPUT.value.name])
305            eval_output = evaluate_dataset(
306                dataset=dataset,
307                pipeline=self._build_pipeline(model, dataset_prompt_template, bertscore_shared_resource),
308                dataset_name=dataset_config.dataset_name,
309                eval_name=self.eval_name,
310                metric_names=ORIGINAL_SCORES + DELTA_SCORES,
311                eval_results_path=get_eval_results_path(),
312                model=model,
313                prompt_template=dataset_prompt_template,
314                agg_method=MEAN,
315                save=save,
316                save_strategy=save_strategy,
317            )
318            eval_outputs.append(eval_output)
319        cleanup_shared_resource(bertscore_shared_resource)
320        return eval_outputs
DELTA_F1_SCORE = 'delta_f1_score'
DELTA_EXACT_MATCH_SCORE = 'delta_exact_match_score'
DELTA_QUASI_EXACT_MATCH_SCORE = 'delta_quasi_exact_match_score'
DELTA_PRECISION_OVER_WORDS = 'delta_precision_over_words'
DELTA_RECALL_OVER_WORDS = 'delta_recall_over_words'
DELTA_BERT_SCORE = 'delta_bertscore'
DELTA_SCORES = ['delta_f1_score', 'delta_exact_match_score', 'delta_quasi_exact_match_score', 'delta_precision_over_words', 'delta_recall_over_words', 'delta_bertscore']
ORIGINAL_SCORES = ['f1_score', 'exact_match_score', 'quasi_exact_match_score', 'precision_over_words', 'recall_over_words', 'bertscore']
@dataclass(frozen=True)
class QAAccuracySemanticRobustnessConfig(fmeval.eval_algorithms.semantic_robustness_utils.SemanticRobustnessConfig):

Configures the QA Accuracy Semantic Robustness evaluation algorithm.

See SemanticRobustnessConfig for the configurable parameters that this config class inherits.

Parameters
  • target_output_delimiter: The target output can have multiple answers. We expect customers to combine all the possible answers into a single string and use the delimiter to separate them. For instance, if the answers are ["UK", "England"] and the delimiter="<OR>", then the target_output should be "UK<OR>England".
  • model_type_for_bertscore: BERT model type to use for computing BERT score.
QAAccuracySemanticRobustnessConfig( perturbation_type: str = 'butter_finger', num_perturbations: int = 5, butter_finger_perturbation_prob: float = 0.1, random_uppercase_corrupt_proportion: float = 0.1, whitespace_add_prob: float = 0.05, whitespace_remove_prob: float = 0.1, target_output_delimiter: Optional[str] = '<OR>', model_type_for_bertscore: str = 'microsoft/deberta-xlarge-mnli')
target_output_delimiter: Optional[str] = '<OR>'
model_type_for_bertscore: str = 'microsoft/deberta-xlarge-mnli'
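
As a quick illustration of the delimiter convention described above, here is a minimal sketch of constructing this config. The field names are taken from the signature above; the values are only examples.

from fmeval.eval_algorithms.qa_accuracy_semantic_robustness import (
    QAAccuracySemanticRobustnessConfig,
)

# Keep the default "<OR>" delimiter but request more perturbations per input.
# num_perturbations is inherited from SemanticRobustnessConfig (see the signature above).
config = QAAccuracySemanticRobustnessConfig(
    num_perturbations=10,
    target_output_delimiter="<OR>",
)

# Multiple acceptable answers are joined with the delimiter into a single target output,
# e.g. ["UK", "England"] becomes "UK<OR>England".
target_output = "<OR>".join(["UK", "England"])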
class QAAccuracySemanticRobustness(fmeval.eval_algorithms.eval_algorithm.EvalAlgorithmInterface):

Semantic Robustness evaluation algorithm for QA Accuracy

This evaluation measures how much QA Accuracy changes as a result of semantic-preserving perturbations on the input. For example, if we apply the whitespace perturbation (adding extra whitespaces at random) to the input text, how much does the quality of the model answer change?

The output difference is measured by computing the QA Accuracy metrics before and after perturbing the inputs. We report the absolute value of the difference in scores on average over $P$ (num_perturbations) perturbed inputs: $$ \frac{1}{P} \sum_{i=1}^{P} |s - \bar{s}_i|,$$ where $s$ is the score produced by the original metric (i.e., exact match, quasi-exact match, precision over words, recall over words, and F1 over words), and $\bar{s}_i$ is the metric evaluated after the i-th perturbation has been applied.

For details on the QA Accuracy metrics, see the QA Accuracy evaluation. For details on perturbations, see the GeneralSemanticRobustness evaluation.
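
To make the delta score above concrete, here is a small standalone sketch (plain Python, not the library's MeanDeltaScores transform) of the quantity reported for each metric: the mean absolute difference between the original score and the scores on the perturbed inputs.

from typing import List

def mean_delta_score(original_score: float, perturbed_scores: List[float]) -> float:
    """Mean absolute difference between the original score s and the P perturbed scores."""
    return sum(abs(original_score - s_bar) for s_bar in perturbed_scores) / len(perturbed_scores)

# Example: F1 is 1.0 on the original input and drops on two of the five perturbed inputs.
print(round(mean_delta_score(1.0, [1.0, 0.8, 1.0, 0.6, 1.0]), 2))  # 0.12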

QAAccuracySemanticRobustness( eval_algorithm_config: QAAccuracySemanticRobustnessConfig = QAAccuracySemanticRobustnessConfig(perturbation_type='butter_finger', num_perturbations=5, butter_finger_perturbation_prob=0.1, random_uppercase_corrupt_proportion=0.1, whitespace_add_prob=0.05, whitespace_remove_prob=0.1, target_output_delimiter='<OR>', model_type_for_bertscore='microsoft/deberta-xlarge-mnli'))

QAAccuracySemanticRobustness initializer.

Parameters
  • eval_algorithm_config: QA Accuracy Semantic Robustness evaluation algorithm config.
eval_name = 'qa_accuracy_semantic_robustness'
config
perturbation_transform
target_output_delimiter
bertscore_model
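
A minimal instantiation sketch based on the initializer documented above. Note that, per the source, constructing the algorithm also loads the BERTScore helper model selected by model_type_for_bertscore.

from fmeval.eval_algorithms.qa_accuracy_semantic_robustness import (
    QAAccuracySemanticRobustness,
    QAAccuracySemanticRobustnessConfig,
)

# Default config: butter_finger perturbation, 5 perturbations per input, "<OR>" delimiter.
algo = QAAccuracySemanticRobustness()

# Or pass an explicit config; butter_finger_perturbation_prob is inherited from
# SemanticRobustnessConfig (see the constructor signature above).
algo = QAAccuracySemanticRobustness(
    QAAccuracySemanticRobustnessConfig(butter_finger_perturbation_prob=0.2)
)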
def evaluate_sample( self, model_input: str, target_output: str, model: fmeval.model_runners.model_runner.ModelRunner, prompt_template: str = '$model_input') -> List[fmeval.eval_algorithms.EvalScore]:

Compute question answering accuracy semantic robustness metrics for a single sample.

A sample is defined as a model input and target output pair.

Parameters
  • model_input: Text input, which will be composed into a prompt that gets fed to the model.
  • target_output: The expected response from the model.
  • model: An instance of ModelRunner representing the model under evaluation.
  • prompt_template: A template used to compose the prompt from model_input.
Returns

A list of EvalScores.
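
A usage sketch for a single sample, assuming `algo` is the instance constructed above and `model` is an already-constructed fmeval ModelRunner for the system under test (both are placeholders here). The returned list contains the original QA Accuracy scores followed by their delta_ counterparts.

# `model` is assumed to be an fmeval ModelRunner you have already set up.
scores = algo.evaluate_sample(
    model_input="Which country is London the capital of?",
    target_output="UK<OR>England",  # multiple acceptable answers joined with the "<OR>" delimiter
    model=model,
    prompt_template="Answer the following question: $model_input",
)
for score in scores:
    print(score.name, score.value)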

def evaluate( self, model: fmeval.model_runners.model_runner.ModelRunner, dataset_config: Union[fmeval.data_loaders.data_config.DataConfig, List[fmeval.data_loaders.data_config.DataConfig], NoneType] = None, prompt_template: Optional[str] = None, num_records: int = 100, save: bool = False, save_strategy: Optional[fmeval.eval_algorithms.save_strategy.SaveStrategy] = None) -> List[fmeval.eval_algorithms.EvalOutput]:

Compute QA accuracy semantic robustness metrics on one or more datasets.

Parameters
  • model: An instance of ModelRunner representing the model under evaluation. This is a required argument, as even if the dataset contains model outputs, semantic robustness algorithms rely on invoking a model on perturbed inputs to see how the model outputs from the perturbed inputs differ from the original model outputs.
  • dataset_config: Configures a single dataset or list of datasets used for the evaluation. If not provided, this method will run evaluations using all of its supported built-in datasets.
  • prompt_template: An optional template used to generate prompts. If not provided, default templates will be used.
  • num_records: The number of records to be sampled randomly from the input dataset to perform the evaluation.
  • save: If set to true, prompt responses and scores will be saved to a file.
  • save_strategy: Specifies the strategy to use to save the localized outputs of the evaluations. If not specified, results will be saved to the path configured by the EVAL_RESULTS_PATH environment variable. If that environment variable is also not configured, they will be saved to the default path /tmp/eval_results/.
Returns

A list of EvalOutput objects.
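
Finally, a dataset-level usage sketch. The DataConfig below describes a hypothetical JSON Lines dataset (the URI, record field names, and JMESPath locations are placeholders to adapt to your data); `algo` and `model` are the algorithm instance and ModelRunner from the earlier sketches.

from fmeval.data_loaders.data_config import DataConfig

# Hypothetical JSON Lines dataset with "question" and "answer" fields in each record.
data_config = DataConfig(
    dataset_name="my_qa_dataset",
    dataset_uri="s3://my-bucket/qa_dataset.jsonl",
    dataset_mime_type="application/jsonlines",
    model_input_location="question",
    target_output_location="answer",
)

eval_outputs = algo.evaluate(
    model=model,                 # required: perturbed inputs must be re-invoked against the model
    dataset_config=data_config,
    prompt_template="Answer the following question: $model_input",
    num_records=100,
    save=True,                   # persist per-record prompts, responses, and scores
)
for eval_output in eval_outputs:
    print(eval_output)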