fmeval.transforms.summarization_accuracy_metrics

import ray
import nltk
import evaluate as hf_evaluate

from abc import abstractmethod
from typing import Any, Dict, Union, List, Optional
from ray.actor import ActorHandle
from nltk import word_tokenize
from nltk.translate import meteor_score

from fmeval.transforms.transform import Transform
from fmeval.transforms.util import validate_call
from fmeval.constants import BERTSCORE_DEFAULT_MODEL
from fmeval.eval_algorithms.helper_models.helper_model import BertscoreHelperModel
from fmeval.util import assert_condition

METEOR_SCORE = "meteor"
ROUGE_SCORE = "rouge"
BERT_SCORE = "bertscore"

# rouge constants
ROUGE_1 = "rouge1"
ROUGE_2 = "rouge2"
ROUGE_L = "rougeL"

ROUGE_TYPES = [ROUGE_1, ROUGE_2, ROUGE_L]

class SummarizationAccuracyMetric(Transform):
    """The abstract base class for summarization accuracy metric transforms.

    Concrete subclasses of SummarizationAccuracyMetric need only implement the
    `compute_metric` method and their own __init__ method. Subclasses need not implement
    the __call__ method, as it is already implemented in this class, but are
    free to do so if additional customization is required.
    """

    def __init__(
        self,
        target_output_keys: Optional[List[str]],
        model_output_keys: List[str],
        output_keys: List[str],
        allow_duplicate_input_keys: bool,
        target_output_keys_provider: Optional[str],
        *args,
        **kwargs,
    ):
        """SummarizationAccuracyMetric initializer.

        Note that the ordering of the elements in `model_output_keys` and `output_keys`
        must match, i.e. the kth element of `model_output_keys` is used
        to compute the kth metric, which has an output key of `output_keys[k]`.

        :param target_output_keys: The keys corresponding to target outputs. If this is
            set to None, then we will use `target_output_keys_provider` to get the
            list of target outputs.
        :param model_output_keys: The keys corresponding to model outputs.
        :param output_keys: The output keys for this Transform, which correspond
            to the metrics/scores that get computed.
        :param allow_duplicate_input_keys: Whether to allow duplicate keys in
            `target_output_keys` and `model_output_keys`. This parameter is usually
            False.
        :param target_output_keys_provider: The key corresponding to a list of target
            outputs. Will only be used if `target_output_keys` is set to None.
        :param *args: Variable length argument list.
        :param **kwargs: Arbitrary keyword arguments.
        """
        assert_condition(
            len(model_output_keys) == len(output_keys),
            "len(model_output_keys) and len(output_keys) should match. "
            f"len(model_output_keys) is {len(model_output_keys)}, "
            f"and len(output_keys) is {len(output_keys)}.",
        )
        if target_output_keys is None:
            assert_condition(
                target_output_keys_provider is not None,
                f"target_output_keys is {target_output_keys}, but target_output_keys_provider"
                f" (the fallback value) is {target_output_keys_provider}, which is invalid.",
            )
        super().__init__(
            target_output_keys,
            model_output_keys,
            output_keys,
            allow_duplicate_input_keys,
            target_output_keys_provider,
            *args,
            **kwargs,
        )
        input_keys = target_output_keys if target_output_keys else [target_output_keys_provider]  # type: ignore
        self.register_input_output_keys(
            input_keys + model_output_keys,
            output_keys,
            allow_duplicates=allow_duplicate_input_keys,
        )
        self.target_output_keys = target_output_keys
        self.model_output_keys = model_output_keys
        self.target_output_keys_provider = target_output_keys_provider

    @validate_call
    def __call__(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Augment the input record with metrics computed via self.compute_metric.

        The max score is computed over all possible targets represented by
        self.target_output_keys and stored in the input record.

        :param record: The input record.
        :returns: The input record with metrics added in.
        """
        target_output_list = (
            [record[target_output_key] for target_output_key in self.target_output_keys]
            if self.target_output_keys
            else record[self.target_output_keys_provider]  # type: ignore[index]
        )
        for model_output_key, output_key in zip(self.model_output_keys, self.output_keys):
            scores = [self.compute_metric(target, record[model_output_key]) for target in target_output_list]
            record[output_key] = max(scores)
        return record

    @abstractmethod
    def compute_metric(self, target_output: str, model_output: str) -> float:
        """Compute the metric that is specific to this Transform.

        :param target_output: The target/reference output.
        :param model_output: The actual output produced by the model.
        :returns: A float representing the computed metric value.
        """

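# Illustrative sketch (not part of the module source): how a concrete subclass
# plugs into the base class. Implement compute_metric and the inherited
# __call__ does the rest, taking the max score over all targets. The class
# name and record keys below are hypothetical.
#
# class ExactMatch(SummarizationAccuracyMetric):
#     def compute_metric(self, target_output: str, model_output: str) -> float:
#         return float(target_output.strip() == model_output.strip())
#
# exact_match = ExactMatch(
#     target_output_keys=["target_1", "target_2"],
#     model_output_keys=["model_output"],
#     output_keys=["exact_match"],
#     allow_duplicate_input_keys=False,
#     target_output_keys_provider=None,
# )
# record = {"target_1": "a cat", "target_2": "a dog", "model_output": "a dog"}
# record = exact_match(record)
# record["exact_match"]  # 1.0: the max over the scores against both targets
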
129    """This Transform augments an input record with the METEOR metric, computed from target and model outputs.
130
131    METEOR is a metric for text similarity between the machine-produced summary
132    and human-produced reference summaries.
133    Unigrams can be matched based on their surface forms, stemmed forms,
134    and meanings; furthermore, METEOR can be easily extended to include more
135    advanced matching strategies. Once all generalized unigram matches
136    between the two strings have been found, METEOR computes a score for
137    this matching using a combination of unigram-precision, unigram-recall, and
138    a measure of fragmentation that is designed to directly capture how
139    well-ordered the matched words in the machine translation are in relation
140    to the reference.
141    """
142
143    def __init__(
144        self,
145        target_output_keys: Optional[List[str]],
146        model_output_keys: List[str],
147        output_keys: List[str],
148        allow_duplicate_input_keys: bool,
149        target_output_keys_provider: Optional[str] = None,
150        load_modules: bool = True,
151    ):
152        """MeteorScore initializer.
153
154        :param target_output_keys: The keys corresponding to target outputs. If this is
155            set to None, then we will use `target_output_keys_provider` to get the
156            list of target outputs.
157        :param model_output_keys: The keys corresponding to model outputs.
158        :param output_keys: The output keys for this Transform, which correspond
159            to the Meteor scores that get computed.
160        :param allow_duplicate_input_keys: Whether to allow duplicate keys in
161            `target_output_keys` and `model_output_keys`. This parameter is usually
162            False.
163        :param target_output_keys_provider: The key corresponding to a list of target
164            outputs. Will only be used if `target_output_keys` is set to None.
165        :param load_modules: Whether to load the meteor helper modules.
166        """
167        super().__init__(
168            target_output_keys,
169            model_output_keys,
170            output_keys,
171            allow_duplicate_input_keys,
172            target_output_keys_provider,
173            # The first instance of this class that gets created will
174            # load the helper modules, so copies of this instance
175            # need not load them again.
176            load_modules=False,
177        )
178        if load_modules:
179            MeteorScore._load_modules()
180
181    @staticmethod
182    def _load_modules() -> None:  # pragma: no cover
183        """Load helper modules required by meteor metric.
184
185        :returns: None
186        """
187        nltk.download("wordnet")
188        nltk.download("punkt")
189        nltk.download("punkt_tab")
190        nltk.download("omw-1.4")
191
192    def compute_metric(self, target_output: str, model_output: str) -> float:
193        """Compute the Meteor metric.
194
195        :param target_output: The target/reference output.
196        :param model_output: The actual output produced by the model.
197        :returns: The meteor metric value.
198        """
199        return meteor_score.single_meteor_score(
200            reference=word_tokenize(target_output),
201            hypothesis=word_tokenize(model_output),
202        )
203
204
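# Illustrative sketch (not part of the module source): typical MeteorScore
# usage on a single record. The record keys are hypothetical. The first
# instantiation downloads the required nltk data (wordnet, punkt, punkt_tab,
# omw-1.4) unless load_modules=False is passed.
#
# meteor = MeteorScore(
#     target_output_keys=["target_output"],
#     model_output_keys=["model_output"],
#     output_keys=["meteor"],
#     allow_duplicate_input_keys=False,
# )
# record = {
#     "target_output": "The cat sat on the mat.",
#     "model_output": "A cat was sitting on the mat.",
# }
# record = meteor(record)
# record["meteor"]  # float in [0, 1]; higher means closer to the reference
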
class RougeScore(SummarizationAccuracyMetric):
    """This transform augments an input record with the ROUGE score, computed from target and model outputs.

    ROUGE-N, where N is one of 1, 2, or L, is a standard metric for summarization quality.
    It computes the word overlap between the reference and model summary. Given that this metric is based on simple
    word overlap statistics, it works best for extractive summaries.
    Note that if we rephrase the summary without changing its meaning, the ROUGE-N score will drop.

    Reference: https://huggingface.co/spaces/evaluate-metric/rouge
    """

    def __init__(
        self,
        target_output_keys: Optional[List[str]],
        model_output_keys: List[str],
        output_keys: List[str],
        allow_duplicate_input_keys: bool,
        target_output_keys_provider: Optional[str] = None,
        rouge_type: str = ROUGE_2,
        use_stemmer: bool = True,
    ):
        """RougeScore initializer.

        :param target_output_keys: The keys corresponding to target outputs. If this is
            set to None, then we will use `target_output_keys_provider` to get the
            list of target outputs.
        :param model_output_keys: The keys corresponding to model outputs.
        :param output_keys: The output keys for this Transform, which correspond
            to the ROUGE scores that get computed.
        :param allow_duplicate_input_keys: Whether to allow duplicate keys in
            `target_output_keys` and `model_output_keys`. This parameter is usually
            False.
        :param target_output_keys_provider: The key corresponding to a list of target
            outputs. Will only be used if `target_output_keys` is set to None.
        :param rouge_type: Which ROUGE type to use: "rouge1", "rouge2", or "rougeL".
        :param use_stemmer: Whether to use a stemmer for ROUGE.
        """
        super().__init__(
            target_output_keys,
            model_output_keys,
            output_keys,
            allow_duplicate_input_keys,
            target_output_keys_provider,
            rouge_type=rouge_type,
            use_stemmer=use_stemmer,
        )
        self.rouge_type = rouge_type
        self.use_stemmer = use_stemmer
        self.rouge_metric = hf_evaluate.load("rouge")

    def compute_metric(self, target_output: str, model_output: str) -> float:
        """Compute the ROUGE metric.

        :param target_output: The target/reference output.
        :param model_output: The actual output produced by the model.
        :returns: The ROUGE metric value.
        """
        return self.rouge_metric.compute(
            predictions=[model_output],
            references=[target_output],
            use_stemmer=self.use_stemmer,
            rouge_types=[self.rouge_type],
        )[self.rouge_type]

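# Illustrative sketch (not part of the module source): RougeScore configured
# with ROUGE-L instead of the default ROUGE-2. The record keys are
# hypothetical; hf_evaluate.load("rouge") fetches the metric implementation
# on first use.
#
# rouge = RougeScore(
#     target_output_keys=["target_output"],
#     model_output_keys=["model_output"],
#     output_keys=["rouge"],
#     allow_duplicate_input_keys=False,
#     rouge_type=ROUGE_L,
# )
# record = {
#     "target_output": "The quick brown fox jumps over the lazy dog.",
#     "model_output": "The quick brown fox jumped over a lazy dog.",
# }
# record = rouge(record)
# record["rouge"]  # ROUGE-L F-measure in [0, 1]
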
class BertScore(SummarizationAccuracyMetric):
    """This transform augments an input record with the BERTScore, computed from target and model outputs.

    BERTScore is a similarity-based metric that compares the embeddings of the prediction and target sentences
    under a learned model, typically from the BERT family.
    This score may lead to increased flexibility compared to ROUGE and METEOR with respect to rephrasing, since
    semantically similar sentences are (typically) embedded similarly.

    See https://huggingface.co/spaces/evaluate-metric/bertscore
    """

    def __init__(
        self,
        target_output_keys: Optional[List[str]],
        model_output_keys: List[str],
        output_keys: List[str],
        allow_duplicate_input_keys: bool,
        target_output_keys_provider: Optional[str] = None,
        bertscore_model: Union[BertscoreHelperModel, ActorHandle] = BertscoreHelperModel(BERTSCORE_DEFAULT_MODEL),
    ):
        """BertScore initializer.

        :param target_output_keys: The keys corresponding to target outputs. If this is
            set to None, then we will use `target_output_keys_provider` to get the
            list of target outputs.
        :param model_output_keys: The keys corresponding to model outputs.
        :param output_keys: The output keys for this Transform, which correspond
            to the BERTScore values that get computed.
        :param allow_duplicate_input_keys: Whether to allow duplicate keys in
            `target_output_keys` and `model_output_keys`. This parameter is usually
            False.
        :param target_output_keys_provider: The key corresponding to a list of target
            outputs. Will only be used if `target_output_keys` is set to None.
        :param bertscore_model: A BertscoreHelperModel instance or a Ray actor handle for a BertscoreHelperModel.
            If no model is provided, this parameter defaults to a BertscoreHelperModel
            built from BERTSCORE_DEFAULT_MODEL.
        """
        super().__init__(
            target_output_keys,
            model_output_keys,
            output_keys,
            allow_duplicate_input_keys,
            target_output_keys_provider,
            bertscore_model,
        )
        self.bertscore_model = bertscore_model

    def compute_metric(self, target_output: str, model_output: str) -> float:
        """Compute the BERTScore metric.

        :param target_output: The target/reference output.
        :param model_output: The actual output produced by the model.
        :returns: The BERTScore metric value.
        """
        if isinstance(self.bertscore_model, BertscoreHelperModel):
            return self.bertscore_model.get_helper_scores(target_output, model_output)
        else:
            return ray.get(  # type: ignore[return-value]
                self.bertscore_model.get_helper_scores.remote(target_output, model_output)  # type: ignore[union-attr]
            )
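
# Illustrative sketch (not part of the module source): BertScore accepts
# either a local BertscoreHelperModel or a Ray actor handle wrapping one;
# compute_metric dispatches on the type. The record keys are hypothetical,
# and constructing the helper model loads the underlying checkpoint
# (BERTSCORE_DEFAULT_MODEL by default).
#
# bertscore = BertScore(
#     target_output_keys=["target_output"],
#     model_output_keys=["model_output"],
#     output_keys=["bertscore"],
#     allow_duplicate_input_keys=False,
# )
# record = {
#     "target_output": "The weather was sunny all day.",
#     "model_output": "It was sunny for the whole day.",
# }
# record = bertscore(record)
# record["bertscore"]  # semantic similarity score from the helper model
#
# To share one model across parallel workers, wrap it in a Ray actor and pass
# the handle instead of a local instance:
#
# bertscore_actor = ray.remote(BertscoreHelperModel).remote(BERTSCORE_DEFAULT_MODEL)
# bertscore = BertScore(
#     target_output_keys=["target_output"],
#     model_output_keys=["model_output"],
#     output_keys=["bertscore"],
#     allow_duplicate_input_keys=False,
#     bertscore_model=bertscore_actor,
# )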