fmeval.eval_algorithms.summarization_accuracy

  1import logging
  2from dataclasses import dataclass
  3from typing import List, Optional, Tuple, Union
  4from ray.actor import ActorHandle
  5
  6from fmeval.data_loaders.util import get_dataset
  7from fmeval.eval_algorithms import EvalAlgorithm, EvalOutput, EvalScore
  8from fmeval.eval_algorithms.common import evaluate_dataset
  9from fmeval.eval_algorithms.eval_algorithm import EvalAlgorithmInterface, EvalAlgorithmConfig
 10from fmeval.eval_algorithms.save_strategy import SaveStrategy
 11from fmeval.eval_algorithms.util import get_dataset_configs, validate_dataset
 12from fmeval.util import (
 13    assert_condition,
 14    require,
 15    create_shared_resource,
 16    get_eval_results_path,
 17    cleanup_shared_resource,
 18)
 19from fmeval.constants import BERTSCORE_DEFAULT_MODEL, DatasetColumns, MEAN
 20from fmeval.transforms.transform_pipeline import TransformPipeline
 21from fmeval.data_loaders.data_config import DataConfig
 22from fmeval.eval_algorithms.helper_models.helper_model import BertscoreHelperModelTypes, BertscoreHelperModel
 23from fmeval.model_runners.model_runner import ModelRunner
 24from fmeval.transforms.summarization_accuracy_metrics import (
 25    MeteorScore,
 26    RougeScore,
 27    BertScore,
 28    METEOR_SCORE,
 29    ROUGE_SCORE,
 30    BERT_SCORE,
 31    ROUGE_2,
 32    ROUGE_TYPES,
 33)
 34
 35
 36METRIC_NAMES = [METEOR_SCORE, ROUGE_SCORE, BERT_SCORE]
 37
 38logger = logging.getLogger(__name__)
 39
 40
 41@dataclass(frozen=True)
 42class SummarizationAccuracyConfig(EvalAlgorithmConfig):
 43    """Configures the summarization accuracy evaluation algorithm.
 44
 45    :param rouge_type: ROUGE metric type.
 46    :param use_stemmer_for_rouge: Whether to use a stemmer when computing the ROUGE metric.
 47    :param model_type_for_bertscore: BERT model type to use for computing BERT score.
 48    """
 49
 50    rouge_type: str = ROUGE_2
 51    use_stemmer_for_rouge: bool = True
 52    model_type_for_bertscore: str = BERTSCORE_DEFAULT_MODEL
 53
 54    def __post_init__(self):
 55        require(
 56            self.rouge_type in ROUGE_TYPES,
 57            f"Invalid rouge_type: {self.rouge_type} requested in SummarizationAccuracyConfig. "
 58            f"Please choose from acceptable values: {ROUGE_TYPES}.",
 59        )
 60        require(
 61            BertscoreHelperModelTypes.model_is_allowed(self.model_type_for_bertscore),
 62            f"Invalid model_type_for_bertscore: {self.model_type_for_bertscore} requested in "
 63            f"SummarizationAccuracyConfig. Please choose from acceptable values: "
 64            f"{BertscoreHelperModelTypes.model_list()}.",
 65        )
 66
 67
 68class SummarizationAccuracy(EvalAlgorithmInterface):
 69    """Summarization Accuracy evaluation algorithm.
 70
 71    This evaluation measures how accurately a model can summarize text. By default, we carry out this evaluation by benchmarking on two built-in datasets containing pairs of input text and target summary. The model summaries are then compared to the target summaries using three built-in metrics that measure how similar the summaries are in different ways:
 72
 73    1. ROUGE-N: ROUGE scores are a class of metrics that compute N-gram word overlaps between reference and model summary.  The metrics are case insensitive and the values are in the range of 0 (no match) to 1 (perfect match). It has the following configurable parameters which can be set in the `SummarizationAccuracyConfig`:
 74        * N: the length of N-grams to be matched. The three supported values are
 75            *  N=1 matches single words (unigrams)
 76            *  N=2 (default) matches word pairs (bigrams)
 77            *  N=L matches the longest common subsequence. For computing the longest common subsequence, order is accounted for, but consecutiveness is not required. E.g., for model summary = "It is autumn" and reference = "It is once again autumn", we have LCS(prediction, reference) = 3.
 78        * use_stemmer: If True (default), uses the [Porter stemmer](https://www.cs.toronto.edu/~frank/csc2501/Readings/R2_Porter/Porter-1980.pdf) to strip word suffixes. For example, "raining" → "rain".
 79    To obtain ROUGE-N, N-gram precision and recall are computed. Those are then aggregated into the final score:
 80    ROUGE-N = 2 * (precision_N * recall_N) / (precision_N + recall_N).
 81
 82    2. [Meteor](https://aclanthology.org/W05-0909.pdf) is similar to ROUGE-1, but includes stemming (with the Porter stemmer) and synonym matching via synonym lists (e.g. “fall” → “autumn”). Because Meteor can match synonyms, it is more robust to paraphrasing than ROUGE.
 83    3. [BERTScore](https://arxiv.org/pdf/1904.09675.pdf) uses a second ML model (from the BERT family) to compute sentence embeddings and compare their cosine similarity. This score may account for additional linguistic flexibility over ROUGE and METEOR since semantically similar sentences should be embedded closer to each other.
 84
 85    Parameters which can be set in the `SummarizationAccuracyConfig` are:
 86    * model_type_for_bertscore: Name of the model used to compute BERTScore; choose one of "microsoft/deberta-xlarge-mnli" (default) and "roberta-large-mnli".
 87
 88
 89    """
 90
 91    eval_name = EvalAlgorithm.SUMMARIZATION_ACCURACY.value
 92
 93    def __init__(self, eval_algorithm_config: SummarizationAccuracyConfig = SummarizationAccuracyConfig()):
 94        """SummarizationAccuracy initializer.
 95
 96        :param eval_algorithm_config: Summarization Accuracy evaluation algorithm config.
 97        """
 98        super().__init__(eval_algorithm_config)
 99        self.bertscore_model = BertscoreHelperModel(eval_algorithm_config.model_type_for_bertscore)
100        meteor_score, rouge_score, bert_score = SummarizationAccuracy._create_transforms(
101            target_output_keys=[DatasetColumns.TARGET_OUTPUT.value.name],
102            model_output_keys=[DatasetColumns.MODEL_OUTPUT.value.name],
103            meteor_keys=[METEOR_SCORE],
104            rouge_keys=[ROUGE_SCORE],
105            bertscore_keys=[BERT_SCORE],
106            rouge_type=eval_algorithm_config.rouge_type,
107            use_stemmer_for_rouge=eval_algorithm_config.use_stemmer_for_rouge,
108            bertscore_model=self.bertscore_model,
109        )
110        self.meteor_score = meteor_score
111        self.rouge_score = rouge_score
112        self.bert_score = bert_score
113        self.pipeline = TransformPipeline([meteor_score, rouge_score, bert_score])
114
115    @staticmethod
116    def _create_transforms(
117        target_output_keys: List[str],
118        model_output_keys: List[str],
119        meteor_keys: List[str],
120        rouge_keys: List[str],
121        bertscore_keys: List[str],
122        rouge_type: str,
123        use_stemmer_for_rouge: bool,
124        bertscore_model: Union[BertscoreHelperModel, ActorHandle],
125    ) -> Tuple[MeteorScore, RougeScore, BertScore]:
126        """Create a TransformPipeline containing summarization accuracy score transforms.
127
128        :param target_output_keys: See the corresponding parameter in MeteorScore, RougeScore, and BertScore.
129        :param model_output_keys: See the corresponding parameter in MeteorScore, RougeScore, and BertScore.
130        :param meteor_keys: The `output_keys` parameter for the returned MeteorScore instance.
131        :param rouge_keys: The `output_keys` parameter for the returned RougeScore instance.
132        :param bertscore_keys: The `output_keys` parameter for the returned BertScore instance.
133        :param rouge_type: See the corresponding parameter in RougeScore.
134        :param use_stemmer_for_rouge: See `use_stemmer` in RougeScore.
135        :param bertscore_model: A BertscoreHelperModel or Ray actor handle corresponding to a BertscoreHelperModel
136            (i.e. a shared resource) used in the creation of the returned BertScore instance.
137        :returns: A tuple containing the created MeteorScore, RougeScore, and BertScore instances.
138        """
139        meteor_transform = MeteorScore(
140            target_output_keys=target_output_keys,
141            model_output_keys=model_output_keys,
142            output_keys=meteor_keys,
143            allow_duplicate_input_keys=True,
144        )
145        rouge_transform = RougeScore(
146            target_output_keys=target_output_keys,
147            model_output_keys=model_output_keys,
148            output_keys=rouge_keys,
149            allow_duplicate_input_keys=True,
150            rouge_type=rouge_type,
151            use_stemmer=use_stemmer_for_rouge,
152        )
153        bert_transform = BertScore(
154            target_output_keys=target_output_keys,
155            model_output_keys=model_output_keys,
156            output_keys=bertscore_keys,
157            allow_duplicate_input_keys=True,
158            bertscore_model=bertscore_model,
159        )
160        return meteor_transform, rouge_transform, bert_transform
161
162    def evaluate_sample(self, target_output: str, model_output: str) -> List[EvalScore]:  # type: ignore[override]
163        """Compute summarization accuracy metrics for a single sample.
164
165        :param target_output: The expected/desired model output.
166        :param model_output: The actual model output.
167        :returns: A list of EvalScore objects, one for each of the summarization accuracy metrics.
168        """
169        sample = {
170            DatasetColumns.TARGET_OUTPUT.value.name: target_output,
171            DatasetColumns.MODEL_OUTPUT.value.name: model_output,
172        }
173        output_record = self.pipeline.execute_record(sample)
174        assert_condition(
175            all(metric_name in output_record for metric_name in METRIC_NAMES),
176            "Summarization Accuracy evaluate_sample has computed an output that is missing at least one metric. "
177            f"The output record is {output_record}.",
178        )
179        return [EvalScore(name=metric_name, value=output_record[metric_name]) for metric_name in METRIC_NAMES]
180
181    def evaluate(
182        self,
183        model: Optional[ModelRunner] = None,
184        dataset_config: Optional[Union[DataConfig, List[DataConfig]]] = None,
185        prompt_template: Optional[str] = None,
186        num_records: int = 100,
187        save: bool = False,
188        save_strategy: Optional[SaveStrategy] = None,
189    ) -> List[EvalOutput]:
190        """Compute summarization accuracy metrics on one or more datasets.
191
192        :param model: An instance of ModelRunner representing the model under evaluation.
193            If this argument is None, the `dataset_config` argument must not be None,
194            and must correspond to a dataset that already contains a column with model outputs.
195        :param dataset_config: Configures a single dataset or list of datasets used for the
196            evaluation. If not provided, this method will run evaluations using all of its
197            supported built-in datasets.
198        :param prompt_template: A template used to generate prompts that are fed to the model.
199            If not provided, defaults will be used. If provided, `model` must not be None.
200        :param num_records: The number of records to be sampled randomly from the input dataset(s)
201            used to perform the evaluation(s).
202        :param save: If set to true, prompt responses and scores will be saved to a file.
203        :param save_strategy: Specifies the strategy used to save the localized outputs of the evaluations.
204            If not specified, the outputs are saved to the path configured by the EVAL_RESULTS_PATH environment
205            variable, or to the default path `/tmp/eval_results/` if that variable is not set.
206
207        :return: A list of EvalOutput objects.
208        """
209        # Create a shared resource to be used during the evaluation.
210        bertscore_shared_resource = create_shared_resource(self.bertscore_model)
211        # Create a new pipeline that uses the shared resource instead of self.bertscore_model.
212        meteor_score, rouge_score, bert_score = SummarizationAccuracy._create_transforms(
213            target_output_keys=[DatasetColumns.TARGET_OUTPUT.value.name],
214            model_output_keys=[DatasetColumns.MODEL_OUTPUT.value.name],
215            meteor_keys=[METEOR_SCORE],
216            rouge_keys=[ROUGE_SCORE],
217            bertscore_keys=[BERT_SCORE],
218            rouge_type=self.rouge_score.rouge_type,
219            use_stemmer_for_rouge=self.rouge_score.use_stemmer,
220            bertscore_model=bertscore_shared_resource,
221        )
222        pipeline = TransformPipeline([meteor_score, rouge_score, bert_score])
223
224        dataset_configs = get_dataset_configs(dataset_config, self.eval_name)
225        eval_outputs = []
226        for dataset_config in dataset_configs:
227            dataset = get_dataset(dataset_config, num_records)
228            validate_dataset(dataset, [DatasetColumns.MODEL_INPUT.value.name, DatasetColumns.TARGET_OUTPUT.value.name])
229            eval_output = evaluate_dataset(
230                dataset=dataset,
231                pipeline=pipeline,
232                dataset_name=dataset_config.dataset_name,
233                eval_name=self.eval_name,
234                metric_names=METRIC_NAMES,
235                eval_results_path=get_eval_results_path(),
236                model=model,
237                prompt_template=prompt_template,
238                agg_method=MEAN,
239                save=save,
240                save_strategy=save_strategy,
241            )
242            eval_outputs.append(eval_output)
243
244        cleanup_shared_resource(bertscore_shared_resource)
245        return eval_outputs
METRIC_NAMES = ['meteor', 'rouge', 'bertscore']
logger = <Logger fmeval.eval_algorithms.summarization_accuracy (WARNING)>
@dataclass(frozen=True)
class SummarizationAccuracyConfig(fmeval.eval_algorithms.eval_algorithm.EvalAlgorithmConfig):

Configures the summarization accuracy evaluation algorithm.

Parameters
  • rouge_type: ROUGE metric type.
  • use_stemmer_for_rouge: Whether to use a stemmer when computing the ROUGE metric.
  • model_type_for_bertscore: BERT model type to use for computing BERT score.
SummarizationAccuracyConfig( rouge_type: str = 'rouge2', use_stemmer_for_rouge: bool = True, model_type_for_bertscore: str = 'microsoft/deberta-xlarge-mnli')
rouge_type: str = 'rouge2'
use_stemmer_for_rouge: bool = True
model_type_for_bertscore: str = 'microsoft/deberta-xlarge-mnli'
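
For illustration, a minimal configuration sketch. It assumes the accepted rouge_type strings are those in ROUGE_TYPES ("rouge1", "rouge2", "rougeL"); an out-of-range value is rejected by require() in __post_init__, as shown in the source above.

from fmeval.eval_algorithms.summarization_accuracy import SummarizationAccuracyConfig

# ROUGE-L with stemming and the smaller BERTScore backbone.
config = SummarizationAccuracyConfig(
    rouge_type="rougeL",
    use_stemmer_for_rouge=True,
    model_type_for_bertscore="roberta-large-mnli",
)

# An unsupported value fails fast in __post_init__ (the exact exception type
# depends on how require() reports validation errors).
# SummarizationAccuracyConfig(rouge_type="rouge3")  # raises a validation error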
class SummarizationAccuracy(fmeval.eval_algorithms.eval_algorithm.EvalAlgorithmInterface):

Summarization Accuracy evaluation algorithm.

This evaluation measures how accurately a model can summarize text. By default, we carry out this evaluation by benchmarking on two built-in datasets containing pairs of input text and target summary. The model summaries are then compared to the target summaries using three built-in metrics that measure how similar the summaries are in different ways:

  1. ROUGE-N: ROUGE scores are a class of metrics that compute N-gram word overlaps between reference and model summary. The metrics are case insensitive and the values are in the range of 0 (no match) to 1 (perfect match). It has the following configurable parameters which can be set in the SummarizationAccuracyConfig:

    • N: the length of N-grams to be matched. The three supported values are
      • N=1 matches single words (unigrams)
      • N=2 (default) matches word pairs (bigrams)
      • N=L matches the longest common subsequence. For computing the longest common subsequence, order is accounted for, but consecutiveness is not required. E.g., for model summary = "It is autumn" and reference = "It is once again autumn", we have LCS(prediction, reference) = 3.
    • use_stemmer: If True (default), uses the Porter stemmer to strip word suffixes. For example, "raining" → "rain".
  To obtain ROUGE-N, N-gram precision and recall are computed. Those are then aggregated into the final score: ROUGE-N = 2 * (precision_N * recall_N) / (precision_N + recall_N). A toy worked example appears after this overview.
  2. Meteor is similar to ROUGE-1, but includes stemming (with the Porter stemmer) and synonym matching via synonym lists (e.g. “fall” → “autumn”). Because Meteor can match synonyms, it is more robust to paraphrasing than ROUGE.

  3. BERTScore uses a second ML model (from the BERT family) to compute sentence embeddings and compare their cosine similarity. This score may account for additional linguistic flexibility over ROUGE and METEOR since semantically similar sentences should be embedded closer to each other.

Parameters which can be set in the SummarizationAccuracyConfig are:

  • model_type_for_bertscore: Name of the model used to compute BERTScore; choose one of "microsoft/deberta-xlarge-mnli" (default) and "roberta-large-mnli".
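
As a toy illustration of the ROUGE aggregation described above (not the library's actual implementation), the following pure-Python sketch reproduces the docstring's LCS example and combines precision and recall into the F1-style score:

def lcs_length(prediction_tokens, reference_tokens):
    # Dynamic-programming longest common subsequence over word tokens.
    rows, cols = len(prediction_tokens), len(reference_tokens)
    table = [[0] * (cols + 1) for _ in range(rows + 1)]
    for i in range(1, rows + 1):
        for j in range(1, cols + 1):
            if prediction_tokens[i - 1] == reference_tokens[j - 1]:
                table[i][j] = table[i - 1][j - 1] + 1
            else:
                table[i][j] = max(table[i - 1][j], table[i][j - 1])
    return table[rows][cols]

prediction = "it is autumn".split()
reference = "it is once again autumn".split()

lcs = lcs_length(prediction, reference)   # 3, as in the example above
precision = lcs / len(prediction)         # 3 / 3 = 1.0
recall = lcs / len(reference)             # 3 / 5 = 0.6
rouge_l = 2 * precision * recall / (precision + recall)
print(lcs, round(rouge_l, 2))             # 3 0.75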
SummarizationAccuracy( eval_algorithm_config: SummarizationAccuracyConfig = SummarizationAccuracyConfig(rouge_type='rouge2', use_stemmer_for_rouge=True, model_type_for_bertscore='microsoft/deberta-xlarge-mnli'))

SummarizationAccuracy initializer.

Parameters
  • eval_algorithm_config: Summarization Accuracy evaluation algorithm config.
eval_name = 'summarization_accuracy'
bertscore_model
meteor_score
rouge_score
bert_score
pipeline
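
A minimal instantiation sketch. Per the module source above, the constructor eagerly builds the BertscoreHelperModel and wires MeteorScore, RougeScore, and BertScore into a TransformPipeline; whether the underlying BERTScore checkpoint is downloaded at construction or on first use is not determined by this listing.

from fmeval.eval_algorithms.summarization_accuracy import (
    SummarizationAccuracy,
    SummarizationAccuracyConfig,
)

# The default config uses ROUGE-2 and the deberta-xlarge-mnli BERTScore backbone.
algo = SummarizationAccuracy(SummarizationAccuracyConfig(rouge_type="rouge1"))
print(algo.eval_name)   # summarization_accuracy
# algo.pipeline is the TransformPipeline reused by evaluate_sample;
# algo.meteor_score, algo.rouge_score, and algo.bert_score hold the transforms.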
def evaluate_sample( self, target_output: str, model_output: str) -> List[fmeval.eval_algorithms.EvalScore]:

Compute summarization accuracy metrics for a single sample.

Parameters
  • target_output: The expected/desired model output.
  • model_output: The actual model output.
Returns

A list of EvalScore objects, one for each of the summarization accuracy metrics.
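
A usage sketch for single-sample scoring. The exact numeric values depend on the configured ROUGE type and BERTScore backbone, so the printed numbers are not fixed.

algo = SummarizationAccuracy()
scores = algo.evaluate_sample(
    target_output="It is once again autumn.",
    model_output="It is autumn.",
)
# Three EvalScore objects named "meteor", "rouge", and "bertscore" (METRIC_NAMES).
print({score.name: round(score.value, 3) for score in scores})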
def evaluate( self, model: Optional[fmeval.model_runners.model_runner.ModelRunner] = None, dataset_config: Union[fmeval.data_loaders.data_config.DataConfig, List[fmeval.data_loaders.data_config.DataConfig], NoneType] = None, prompt_template: Optional[str] = None, num_records: int = 100, save: bool = False, save_strategy: Optional[fmeval.eval_algorithms.save_strategy.SaveStrategy] = None) -> List[fmeval.eval_algorithms.EvalOutput]:

Compute summarization accuracy metrics on one or more datasets.

Parameters
  • model: An instance of ModelRunner representing the model under evaluation. If this argument is None, the dataset_config argument must not be None, and must correspond to a dataset that already contains a column with model outputs.
  • dataset_config: Configures a single dataset or list of datasets used for the evaluation. If not provided, this method will run evaluations using all of its supported built-in datasets.
  • prompt_template: A template used to generate prompts that are fed to the model. If not provided, defaults will be used. If provided, model must not be None.
  • num_records: The number of records to be sampled randomly from the input dataset(s) used to perform the evaluation(s).
  • save: If set to true, prompt responses and scores will be saved to a file.
  • save_strategy: Specifies the strategy used to save the localized outputs of the evaluations. If not specified, the outputs are saved to the path configured by the EVAL_RESULTS_PATH environment variable, or to the default path /tmp/eval_results/ if that variable is not set.
Returns

A list of EvalOutput objects.
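
As a sketch, an evaluation over a custom dataset with precomputed model outputs, so `model` stays None and no inference is run. The dataset URI and the JSONLines keys ("document", "summary", "model_summary") are hypothetical; the DataConfig field names, the MIME_TYPE_JSONLINES constant, and the EvalOutput attributes (dataset_name, dataset_scores) are assumed from fmeval's documented usage elsewhere and should be checked against your installed version.

from fmeval.constants import MIME_TYPE_JSONLINES
from fmeval.data_loaders.data_config import DataConfig
from fmeval.eval_algorithms.summarization_accuracy import SummarizationAccuracy

data_config = DataConfig(
    dataset_name="my_summarization_dataset",
    dataset_uri="s3://my-bucket/summaries.jsonl",   # hypothetical location
    dataset_mime_type=MIME_TYPE_JSONLINES,
    model_input_location="document",        # required by validate_dataset
    target_output_location="summary",       # required by validate_dataset
    model_output_location="model_summary",  # precomputed model outputs
)

eval_outputs = SummarizationAccuracy().evaluate(
    dataset_config=data_config,
    num_records=50,
    save=True,  # per-record outputs go to EVAL_RESULTS_PATH or /tmp/eval_results/
)
for output in eval_outputs:
    print(output.dataset_name, output.dataset_scores)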