fmeval.eval_algorithms.summarization_accuracy_semantic_robustness

  1import logging
  2
  3from dataclasses import dataclass
  4from typing import List, Optional, Union
  5
  6from ray.actor import ActorHandle
  7
  8from fmeval.constants import (
  9    DatasetColumns,
 10    PREFIX_FOR_DELTA_SCORES,
 11    BERTSCORE_DEFAULT_MODEL,
 12    MEAN,
 13)
 14from fmeval.data_loaders.data_config import DataConfig
 15from fmeval.data_loaders.util import get_dataset
 16from fmeval.eval_algorithms import (
 17    EvalAlgorithm,
 18    EvalScore,
 19    EvalOutput,
 20    DEFAULT_PROMPT_TEMPLATE,
 21    get_default_prompt_template,
 22)
 23from fmeval.eval_algorithms.common import evaluate_dataset
 24from fmeval.eval_algorithms.eval_algorithm import EvalAlgorithmInterface
 25from fmeval.eval_algorithms.save_strategy import SaveStrategy
 26from fmeval.eval_algorithms.semantic_robustness_utils import (
 27    SemanticRobustnessConfig,
 28    get_perturbation_transform,
 29    get_model_outputs_from_perturbed_inputs,
 30)
 31from fmeval.transforms.semantic_robustness_metrics import MeanDeltaScores
 32from fmeval.eval_algorithms.summarization_accuracy import (
 33    SummarizationAccuracy,
 34    ROUGE_TYPES,
 35    ROUGE_2,
 36    ROUGE_SCORE,
 37    METEOR_SCORE,
 38    BERT_SCORE,
 39)
 40from fmeval.eval_algorithms.helper_models.helper_model import BertscoreHelperModel, BertscoreHelperModelTypes
 41from fmeval.eval_algorithms.util import (
 42    get_dataset_configs,
 43    create_model_invocation_pipeline,
 44    validate_dataset,
 45)
 46from fmeval.transforms.summarization_accuracy_metrics import MeteorScore, RougeScore, BertScore
 47from fmeval.transforms.transform_pipeline import TransformPipeline
 48from fmeval.transforms.util import create_output_key
 49from fmeval.model_runners.model_runner import ModelRunner
 50from fmeval.util import create_shared_resource, get_eval_results_path, require, cleanup_shared_resource
 51
 52logger = logging.getLogger(__name__)
 53
 54DELTA_ROUGE_SCORE = PREFIX_FOR_DELTA_SCORES + ROUGE_SCORE
 55DELTA_METEOR_SCORE = PREFIX_FOR_DELTA_SCORES + METEOR_SCORE
 56DELTA_BERT_SCORE = PREFIX_FOR_DELTA_SCORES + BERT_SCORE
 57DELTA_SCORES = [DELTA_METEOR_SCORE, DELTA_ROUGE_SCORE, DELTA_BERT_SCORE]
 58ORIGINAL_SCORES = [METEOR_SCORE, ROUGE_SCORE, BERT_SCORE]
 59
 60
 61@dataclass(frozen=True)
 62class SummarizationAccuracySemanticRobustnessConfig(SemanticRobustnessConfig):
 63    """Configures the summarization accuracy semantic robustness evaluation algorithm.
 64
 65    See SemanticRobustnessConfig for the configurable parameters that this config class inherits.
 66
 67    :param rouge_type: ROUGE metric type.
 68    :param use_stemmer_for_rouge: Whether to use a stemmer when computing the ROUGE metric.
 69    :param model_type_for_bertscore: BERT model type to use for computing BERT score.
 70    """
 71
 72    rouge_type: str = ROUGE_2
 73    use_stemmer_for_rouge: bool = True
 74    model_type_for_bertscore: str = BERTSCORE_DEFAULT_MODEL
 75
 76    def __post_init__(self):
 77        super().__post_init__()
 78        require(
 79            self.rouge_type in ROUGE_TYPES,
 80            f"Invalid rouge_type: {self.rouge_type} requested in SummarizationAccuracySemanticRobustnessConfig. "
 81            f"Please choose from acceptable values: {ROUGE_TYPES}.",
 82        )
 83        require(
 84            BertscoreHelperModelTypes.model_is_allowed(self.model_type_for_bertscore),
 85            f"Invalid model_type_for_bertscore: {self.model_type_for_bertscore} requested in "
 86            f"SummarizationAccuracySemanticRobustnessConfig. Please choose from acceptable values: {BertscoreHelperModelTypes.model_list()}.",
 87        )
 88
 89
 90class SummarizationAccuracySemanticRobustness(EvalAlgorithmInterface):
 91    """Semantic Robustness evaluation algorithm for Summarization Accuracy
 92
 93    This evaluation measures how much Summarization Accuracy changes as a result of semantic-preserving
 94    perturbations on the input. For example, if we apply the whitespace perturbation (adding extra whitespaces at random) to the input text,
 95    how much does the quality of the model summary change?
 96
 97    The output difference is measured by computing the Summarization Accuracy metrics before and after perturbing the inputs. We report the absolute value of the difference in scores,
 98    averaged over the $P$ (`num_perturbations`) perturbed inputs: $$ \frac{1}{P} \sum_{i=1}^{P} |s - \bar{s}_i|,$$
 99    where $s$ is the score produced by the original metric (i.e., ROUGE, METEOR, or BERTScore), and $\bar{s}_i$ is the same metric evaluated after the i-th perturbation has been applied.
100
101    For details on the Summarization Accuracy metrics, see the Summarization Accuracy evaluation. For details on perturbations, see the GeneralSemanticRobustness evaluation.
102    """
103
104    eval_name = EvalAlgorithm.SUMMARIZATION_ACCURACY_SEMANTIC_ROBUSTNESS.value
105
106    def __init__(
107        self,
108        eval_algorithm_config: SummarizationAccuracySemanticRobustnessConfig = SummarizationAccuracySemanticRobustnessConfig(),
109    ):
110        """SummarizationAccuracySemanticRobustness initializer.
111
112        :param eval_algorithm_config: Summarization accuracy semantic robustness evaluation algorithm config.
113        """
114        super().__init__(eval_algorithm_config)
115        self.config = eval_algorithm_config
116        self.perturbation_transform = get_perturbation_transform(eval_algorithm_config)
117        bertscore_model = BertscoreHelperModel(eval_algorithm_config.model_type_for_bertscore)
118        self.bertscore_model = bertscore_model
119
120    def _build_pipeline(
121        self,
122        model: ModelRunner,
123        prompt_template: str,
124        bertscore_model: Union[BertscoreHelperModel, ActorHandle],
125    ) -> TransformPipeline:
126        """Build the TransformPipeline to be used by `evaluate` and `evaluate_sample`.
127
128        While other evaluation algorithms (e.g., Summarization Accuracy) can configure
129        their TransformPipeline at algorithm initialization, the Summarization Accuracy
130        Semantic Robustness algorithm's evaluation logic depends on the ModelRunner
131        and prompt template, which are evaluation-specific (i.e., these parameters aren't
132        configured at the algorithm level). As a result, the pipeline used by this algorithm
133        is built when `evaluate` or `evaluate_sample` is called.
134
135        :param model: The ModelRunner representing the model under evaluation.
136        :param prompt_template: A template that is used to construct the prompt fed to the model.
137        :param bertscore_model: Either a BertscoreHelperModel instance or a Ray actor handle corresponding
138            to a BertscoreHelperModel (i.e. a shared resource).
139        :returns: A TransformPipeline that can be used by either `evaluate_sample` or `evaluate`.
140        """
141        transforms = get_model_outputs_from_perturbed_inputs(
142            self.perturbation_transform,
143            prompt_template,
144            model,
145        )
146        get_perturbed_inputs, gen_perturbed_prompts, get_perturbed_responses = transforms
147
148        meteor_score, rouge_score, bert_score = SummarizationAccuracy._create_transforms(
149            target_output_keys=[DatasetColumns.TARGET_OUTPUT.value.name],
150            model_output_keys=[DatasetColumns.MODEL_OUTPUT.value.name],
151            meteor_keys=[METEOR_SCORE],
152            rouge_keys=[ROUGE_SCORE],
153            bertscore_keys=[BERT_SCORE],
154            rouge_type=self.config.rouge_type,
155            use_stemmer_for_rouge=self.config.use_stemmer_for_rouge,
156            bertscore_model=bertscore_model,
157        )
158
159        perturbed_meteor, perturbed_rouge, perturbed_bert_score = SummarizationAccuracy._create_transforms(
160            target_output_keys=[DatasetColumns.TARGET_OUTPUT.value.name],
161            model_output_keys=get_perturbed_responses.output_keys,
162            meteor_keys=[
163                create_output_key(MeteorScore.__name__, "perturbed", i) for i in range(self.config.num_perturbations)
164            ],
165            rouge_keys=[
166                create_output_key(RougeScore.__name__, "perturbed", i) for i in range(self.config.num_perturbations)
167            ],
168            bertscore_keys=[
169                create_output_key(BertScore.__name__, "perturbed", i) for i in range(self.config.num_perturbations)
170            ],
171            rouge_type=self.config.rouge_type,
172            use_stemmer_for_rouge=self.config.use_stemmer_for_rouge,
173            bertscore_model=bertscore_model,
174        )
175
176        delta_meteor_key = DELTA_METEOR_SCORE
177        delta_rouge_key = DELTA_ROUGE_SCORE
178        delta_bert_key = DELTA_BERT_SCORE
179        mean_delta_scores = MeanDeltaScores(
180            {
181                meteor_score.output_keys[0]: (perturbed_meteor.output_keys, delta_meteor_key),
182                rouge_score.output_keys[0]: (perturbed_rouge.output_keys, delta_rouge_key),
183                bert_score.output_keys[0]: (perturbed_bert_score.output_keys, delta_bert_key),
184            }
185        )
186
187        transforms = [
188            get_perturbed_inputs,
189            gen_perturbed_prompts,
190            get_perturbed_responses,
191            meteor_score,
192            rouge_score,
193            bert_score,
194            perturbed_meteor,
195            perturbed_rouge,
196            perturbed_bert_score,
197            mean_delta_scores,
198        ]
199        pipeline = TransformPipeline(transforms)
200        return pipeline
201
202    def evaluate_sample(
203        self,
204        model_input: str,
205        target_output: str,
206        model: ModelRunner,
207        prompt_template: str = DEFAULT_PROMPT_TEMPLATE,
208    ) -> List[EvalScore]:
209        """Compute summarization accuracy semantic robustness metrics for a single sample.
210
211        A sample is defined as a model input and target output pair.
212
213        :param model_input: Text input, which will be composed into a prompt that gets fed to the model.
214        :param target_output: The expected response from the model.
215        :param model: An instance of ModelRunner representing the model under evaluation.
216        :param prompt_template: A template used to compose the prompt from `model_input`.
217        :return: A list of EvalScores.
218        """
219        sample = {
220            DatasetColumns.MODEL_INPUT.value.name: model_input,
221            DatasetColumns.TARGET_OUTPUT.value.name: target_output,
222        }
223        invoke_model = create_model_invocation_pipeline(model, prompt_template)
224        compute_metrics = self._build_pipeline(model, prompt_template, self.bertscore_model)
225        pipeline = TransformPipeline([invoke_model, compute_metrics])
226        output_record = pipeline.execute_record(sample)
227
228        original_scores = [
229            EvalScore(name=score_name, value=output_record[score_name]) for score_name in ORIGINAL_SCORES
230        ]
231        delta_scores = [
232            EvalScore(name=delta_score_name, value=output_record[delta_score_name]) for delta_score_name in DELTA_SCORES
233        ]
234        return original_scores + delta_scores
235
236    def evaluate(
237        self,
238        model: ModelRunner,
239        dataset_config: Optional[Union[DataConfig, List[DataConfig]]] = None,
240        prompt_template: Optional[str] = None,
241        num_records: int = 100,
242        save: bool = False,
243        save_strategy: Optional[SaveStrategy] = None,
244    ) -> List[EvalOutput]:
245        """
246        Compute summarization accuracy semantic robustness metrics on one or more datasets.
247
248        :param model: An instance of ModelRunner representing the model under evaluation.
249            This is a required argument, as even if the dataset contains model outputs,
250            semantic robustness algorithms rely on invoking a model on perturbed inputs
251            to see how the model outputs from the perturbed inputs differ from the original
252            model outputs.
253        :param dataset_config: Configures a single dataset or list of datasets used for the
254            evaluation. If not provided, this method will run evaluations using all of its
255            supported built-in datasets.
256    :param prompt_template: A template used to generate prompts. If not provided, a default
257        prompt template will be used.
258    :param num_records: The number of records to be sampled randomly from the input dataset
259        to perform the evaluation.
260    :param save: If set to True, prompt responses and scores will be saved to a file.
261    :param save_strategy: Specifies the strategy to use to save the localized outputs of the evaluations. If not
262        specified, outputs are saved to the path configured by the EVAL_RESULTS_PATH environment variable.
263        If that environment variable is also not configured, they are saved to the default path `/tmp/eval_results/`.
264        :return: List of EvalOutput objects.
265        """
266        # Create a shared resource to be used during the evaluation.
267        bertscore_shared_resource = create_shared_resource(self.bertscore_model)
268
269        dataset_configs = get_dataset_configs(dataset_config, self.eval_name)
270        eval_outputs = []
271        for dataset_config in dataset_configs:
272            dataset_prompt_template = (
273                get_default_prompt_template(dataset_config.dataset_name) if not prompt_template else prompt_template
274            )
275            dataset = get_dataset(dataset_config, num_records)
276            validate_dataset(dataset, [DatasetColumns.MODEL_INPUT.value.name, DatasetColumns.TARGET_OUTPUT.value.name])
277            eval_output = evaluate_dataset(
278                dataset=dataset,
279                pipeline=self._build_pipeline(model, dataset_prompt_template, bertscore_shared_resource),
280                dataset_name=dataset_config.dataset_name,
281                eval_name=self.eval_name,
282                metric_names=ORIGINAL_SCORES + DELTA_SCORES,
283                eval_results_path=get_eval_results_path(),
284                model=model,
285                prompt_template=dataset_prompt_template,
286                agg_method=MEAN,
287                save=save,
288                save_strategy=save_strategy,
289            )
290            eval_outputs.append(eval_output)
291
292        cleanup_shared_resource(bertscore_shared_resource)
293        return eval_outputs
DELTA_ROUGE_SCORE = 'delta_rouge'
DELTA_METEOR_SCORE = 'delta_meteor'
DELTA_BERT_SCORE = 'delta_bertscore'
DELTA_SCORES = ['delta_meteor', 'delta_rouge', 'delta_bertscore']
ORIGINAL_SCORES = ['meteor', 'rouge', 'bertscore']
@dataclass(frozen=True)
class SummarizationAccuracySemanticRobustnessConfig(fmeval.eval_algorithms.semantic_robustness_utils.SemanticRobustnessConfig):

Configures the summarization accuracy semantic robustness evaluation algorithm.

See SemanticRobustnessConfig for the configurable parameters that this config class inherits.

Parameters
  • rouge_type: ROUGE metric type.
  • use_stemmer_for_rouge: Whether to use a stemmer when computing the ROUGE metric.
  • model_type_for_bertscore: BERT model type to use for computing BERT score.
SummarizationAccuracySemanticRobustnessConfig( perturbation_type: str = 'butter_finger', num_perturbations: int = 5, butter_finger_perturbation_prob: float = 0.1, random_uppercase_corrupt_proportion: float = 0.1, whitespace_add_prob: float = 0.05, whitespace_remove_prob: float = 0.1, rouge_type: str = 'rouge2', use_stemmer_for_rouge: bool = True, model_type_for_bertscore: str = 'microsoft/deberta-xlarge-mnli')
rouge_type: str = 'rouge2'
use_stemmer_for_rouge: bool = True
model_type_for_bertscore: str = 'microsoft/deberta-xlarge-mnli'
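
A minimal sketch of constructing this config (the "rouge1" choice and the num_perturbations override are illustrative assumptions; invalid values raise via the require() checks in __post_init__):

    from fmeval.eval_algorithms.summarization_accuracy_semantic_robustness import (
        SummarizationAccuracySemanticRobustnessConfig,
    )

    # num_perturbations is inherited from SemanticRobustnessConfig;
    # rouge_type must be one of ROUGE_TYPES, or __post_init__ raises.
    config = SummarizationAccuracySemanticRobustnessConfig(
        num_perturbations=3,
        rouge_type="rouge1",
        use_stemmer_for_rouge=True,
    )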
class SummarizationAccuracySemanticRobustness(fmeval.eval_algorithms.eval_algorithm.EvalAlgorithmInterface):

Semantic Robustness evaluation algorithm for Summarization Accuracy

This evaluation measures how much Summarization Accuracy changes as a result of semantic-preserving perturbations on the input. For example, if we apply the whitespace perturbation (adding extra whitespaces at random) to the input text, how much does the quality of the model summary change?

The output difference is measured by computing the Summarization Accuracy metrics before and after perturbing the inputs. We report the absolute value of the difference in scores, averaged over the $P$ (num_perturbations) perturbed inputs: $$ \frac{1}{P} \sum_{i=1}^{P} |s - \bar{s}_i|,$$ where $s$ is the score produced by the original metric (i.e., ROUGE, METEOR, or BERTScore), and $\bar{s}_i$ is the same metric evaluated after the i-th perturbation has been applied.

For details on the Summarization Accuracy metrics, see the Summarization Accuracy evaluation. For details on perturbations, see the GeneralSemanticRobustness evaluation.
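
To make the formula above concrete, here is a plain-Python sketch (independent of the library's MeanDeltaScores transform) of the delta score for a single metric, with illustrative numbers:

    # Mean absolute difference between the original score s and each perturbed score s_bar_i.
    def mean_delta_score(original_score: float, perturbed_scores: list) -> float:
        return sum(abs(original_score - s) for s in perturbed_scores) / len(perturbed_scores)

    # Example: original ROUGE score 0.42 and scores on three perturbed inputs.
    delta_rouge = mean_delta_score(0.42, [0.40, 0.35, 0.44])  # (0.02 + 0.07 + 0.02) / 3 ≈ 0.0367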

SummarizationAccuracySemanticRobustness( eval_algorithm_config: SummarizationAccuracySemanticRobustnessConfig = SummarizationAccuracySemanticRobustnessConfig(perturbation_type='butter_finger', num_perturbations=5, butter_finger_perturbation_prob=0.1, random_uppercase_corrupt_proportion=0.1, whitespace_add_prob=0.05, whitespace_remove_prob=0.1, rouge_type='rouge2', use_stemmer_for_rouge=True, model_type_for_bertscore='microsoft/deberta-xlarge-mnli'))

SummarizationAccuracySemanticRobustness initializer.

Parameters
  • eval_algorithm_config: Summarization accuracy semantic robustness evaluation algorithm config.
eval_name = 'summarization_accuracy_semantic_robustness'
config
perturbation_transform
bertscore_model
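
A minimal instantiation sketch (the default config is used when none is passed; the explicit override shown is an illustrative assumption):

    from fmeval.eval_algorithms.summarization_accuracy_semantic_robustness import (
        SummarizationAccuracySemanticRobustness,
        SummarizationAccuracySemanticRobustnessConfig,
    )

    # Default configuration: butter_finger perturbation, 5 perturbations, rouge2, etc.
    algo = SummarizationAccuracySemanticRobustness()

    # Or override selected fields of the config.
    algo = SummarizationAccuracySemanticRobustness(
        SummarizationAccuracySemanticRobustnessConfig(num_perturbations=3)
    )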
def evaluate_sample( self, model_input: str, target_output: str, model: fmeval.model_runners.model_runner.ModelRunner, prompt_template: str = '$model_input') -> List[fmeval.eval_algorithms.EvalScore]:

Compute summarization accuracy semantic robustness metrics for a single sample.

A sample is defined as a model input and target output pair.

Parameters
  • model_input: Text input, which will be composed into a prompt that gets fed to the model.
  • target_output: The expected response from the model.
  • model: An instance of ModelRunner representing the model under evaluation.
  • prompt_template: A template used to compose the prompt from model_input.
Returns

A list of EvalScores.
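
A minimal usage sketch. `algo` is the instance from the sketch above, and `runner` is assumed to be a ModelRunner implementation configured elsewhere (e.g., for a SageMaker or Bedrock endpoint); the prompt text and inputs are illustrative:

    scores = algo.evaluate_sample(
        model_input="The quarterly report shows revenue grew 12% year over year ...",
        target_output="Revenue grew 12% year over year.",
        model=runner,
        prompt_template="Summarize the following text: $model_input",
    )
    # Returns meteor, rouge, and bertscore plus their delta_* counterparts.
    for score in scores:
        print(score.name, score.value)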

def evaluate( self, model: fmeval.model_runners.model_runner.ModelRunner, dataset_config: Union[fmeval.data_loaders.data_config.DataConfig, List[fmeval.data_loaders.data_config.DataConfig], NoneType] = None, prompt_template: Optional[str] = None, num_records: int = 100, save: bool = False, save_strategy: Optional[fmeval.eval_algorithms.save_strategy.SaveStrategy] = None) -> List[fmeval.eval_algorithms.EvalOutput]:

Compute summarization accuracy semantic robustness metrics on one or more datasets.

Parameters
  • model: An instance of ModelRunner representing the model under evaluation. This is a required argument, as even if the dataset contains model outputs, semantic robustness algorithms rely on invoking a model on perturbed inputs to see how the model outputs from the perturbed inputs differ from the original model outputs.
  • dataset_config: Configures a single dataset or list of datasets used for the evaluation. If not provided, this method will run evaluations using all of its supported built-in datasets.
  • prompt_template: A template used to generate prompts. If not provided, a default prompt template will be used.
  • num_records: The number of records to be sampled randomly from the input dataset to perform the evaluation.
  • save: If set to True, prompt responses and scores will be saved to a file.
  • save_strategy: Specifies the strategy to use to save the localized outputs of the evaluations. If not specified, outputs are saved to the path configured by the EVAL_RESULTS_PATH environment variable. If that environment variable is also not configured, they are saved to the default path /tmp/eval_results/.
Returns

List of EvalOutput objects.
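
A minimal end-to-end sketch with a custom dataset. The file path, column names, and JSON Lines layout are illustrative assumptions, and `runner` is the ModelRunner assumed in the earlier sketch:

    from fmeval.data_loaders.data_config import DataConfig

    # Hypothetical JSON Lines dataset whose records contain "document" and "summary" fields.
    data_config = DataConfig(
        dataset_name="my_summarization_dataset",
        dataset_uri="dataset.jsonl",
        dataset_mime_type="application/jsonlines",
        model_input_location="document",
        target_output_location="summary",
    )

    eval_outputs = algo.evaluate(
        model=runner,
        dataset_config=data_config,
        prompt_template="Summarize the following text: $model_input",
        num_records=20,
        save=True,
    )
    print(eval_outputs)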