fmeval.eval_algorithms.helper_models.helper_model

  1import evaluate as hf_evaluate
  2import torch
  3import transformers
  4
  5from enum import Enum
  6from abc import ABC, abstractmethod
  7from typing import Any, cast, Dict, List
  8from transformers import pipeline, AutoConfig
  9
 10TOXIGEN_SCORE_NAME = "toxicity"
 11
 12DETOXIFY_SCORE_TOXICITY = "toxicity"
 13DETOXIFY_SCORE_SEVERE_TOXICITY = "severe_toxicity"
 14DETOXIFY_SCORE_OBSCENE = "obscene"
 15DETOXIFY_SCORE_IDENTITY_ATTACK = "identity_attack"
 16DETOXIFY_SCORE_INSULT = "insult"
 17DETOXIFY_SCORE_THREAT = "threat"
 18DETOXIFY_SCORE_SEXUAL_EXPLICIT = "sexual_explicit"
 19DETOXIFY_SCORE_NAMES = [
 20    DETOXIFY_SCORE_TOXICITY,
 21    DETOXIFY_SCORE_SEVERE_TOXICITY,
 22    DETOXIFY_SCORE_OBSCENE,
 23    DETOXIFY_SCORE_IDENTITY_ATTACK,
 24    DETOXIFY_SCORE_INSULT,
 25    DETOXIFY_SCORE_THREAT,
 26    DETOXIFY_SCORE_SEXUAL_EXPLICIT,
 27]
 28
 29
 30class BaseHelperModel(ABC):
 31    """
 32    Base class for 3P helper model invoker. Note: These Helper models are inherently
 33    Machine learning models being used by Evaluation algorithms.
 34    """
 35
 36    @abstractmethod
 37    def get_helper_scores(self, text_input: str) -> Any:
 38        """
 39        Method to invoke helper model
 40        :param text_input: model text input
 41        :returns: model output
 42        """
 43
 44
 45class ToxigenHelperModel(BaseHelperModel):
 46    """
 47    Helper model for toxigen model: https://huggingface.co/tomh/toxigen_roberta/tree/main
 48    """
 49
 50    TOXIGEN_MODEL_NAME = "tomh/toxigen_roberta"
 51
 52    def __init__(self):
 53        """
 54        Constructor to locally load the helper model for inference.
 55        """
 56        self._model = pipeline("text-classification", model=self.TOXIGEN_MODEL_NAME)
 57
 58    def __reduce__(self):
 59        """Serializer method so that instances of this class can be made into shared resources."""
 60        return self.__class__, ()
 61
 62    def get_helper_scores(self, text_input: List[str]) -> Dict[str, List[float]]:  # type: ignore[override]
 63        """
 64        Method to get scores from ToxigenHelper
 65        :param text_input: list of text inputs for the model
 66        :returns: dict with key as score name and value being list of scores for text inputs
 67
 68        Note: Toxigen scores are for label: LABEL_1
 69        """
 70        inference_output = self._model(text_input)
 71        result = {
 72            TOXIGEN_SCORE_NAME: [x["score"] if x["label"] == "LABEL_1" else 1.0 - x["score"] for x in inference_output]
 73        }
 74        return result
 75
 76    @staticmethod
 77    def get_score_names() -> List[str]:
 78        """
 79        Util method to return name of scores generated by helper model
 80        :returns: List of score names
 81        """
 82        return [TOXIGEN_SCORE_NAME]
 83
 84
 85class DetoxifyHelperModel(BaseHelperModel):
 86    """
 87    Helper model for Detoxify: https://github.com/unitaryai/detoxify
 88
 89    Note: we load the unbiased model directly from the state dict due to dependency conflicts between detoxify and
 90    transformers libraries.
 91
 92    TODO: To be switched to consuming HF model once consistency issue is resolved:
 93    https://huggingface.co/unitary/unbiased-toxic-roberta. This will allow removing detoxify PyPI as a dependency,
 94    update transformers version we are consuming.
 95    """
 96
 97    UNBIASED_MODEL_URL = (
 98        "https://github.com/unitaryai/detoxify/releases/download/v0.3-alpha/toxic_debiased-c7548aa0.ckpt"
 99    )
100
101    def __init__(self):
102        """
103        Constructor to locally load the helper model for inference.
104        """
105        state_dict = torch.hub.load_state_dict_from_url(self.UNBIASED_MODEL_URL, map_location="cpu")
106        config = state_dict["config"]["arch"]["args"]
107        self._model = (
108            getattr(transformers, config["model_name"])
109            .from_pretrained(
110                pretrained_model_name_or_path=None,
111                config=AutoConfig.from_pretrained(config["model_type"], num_labels=config["num_classes"]),
112                state_dict=state_dict["state_dict"],
113                local_files_only=False,
114            )
115            .to("cpu")
116        )
117        self._tokenizer = getattr(transformers, config["tokenizer_name"]).from_pretrained(config["model_type"])
118
119    def __reduce__(self):
120        """Serializer method so that instances of this class can be made into shared resources."""
121        return self.__class__, ()
122
123    def get_helper_scores(self, text_input: List[str]) -> Dict[str, List[float]]:  # type: ignore[override]
124        """
125        Method to get scores from DetoxifyHelper
126        :param text_input: list of text inputs for the model
127        :returns: dict with keys as score name and value being list of scores for text inputs
128        """
129        inputs = self._tokenizer(text_input, return_tensors="pt", truncation=True, padding=True).to(self._model.device)
130        scores = torch.sigmoid(self._model(**inputs)[0]).cpu().detach().numpy()
131        return {
132            score_name: [score[i].tolist() for score in scores]
133            for i, score_name in enumerate(DetoxifyHelperModel.get_score_names())
134        }
135
136    @staticmethod
137    def get_score_names() -> List[str]:
138        """
139        Util method to return name of scores generated by helper model
140        :returns: List of score names
141        """
142        return DETOXIFY_SCORE_NAMES
143
144
class BertscoreHelperModel(BaseHelperModel):
    """
    BERTScore is a similarity-based metric that compares the embeddings of the prediction and target sentences
    under a (learned) model, typically from the BERT family.
    Because semantically similar sentences are (typically) embedded similarly, this score can be more tolerant of
    rephrasing than ROUGE and METEOR.
    See https://huggingface.co/spaces/evaluate-metric/bertscore
    Note: we specify that this Ray actor requires num_cpus=1 in order to limit the number of concurrently
    running tasks or actors to avoid out of memory issues.
    See https://docs.ray.io/en/latest/ray-core/patterns/limit-running-tasks.html#core-patterns-limit-running-tasks
    for a detailed explanation.
    """

    def __init__(self, model_type: str):  # pragma: no cover
        """
        Default constructor
        :param model_type: Model type to be used for bertscore
        """
        self._bertscore = hf_evaluate.load("bertscore")
        self._model_type = model_type

    def __reduce__(self):
        """Serializer method so that instances of this class can be made into shared resources."""
        # Only the model type is pickled; __init__ reloads the metric when the object is unpickled.
        return self.__class__, (self._model_type,)

    def get_helper_scores(self, target_output: str, model_output: str) -> float:  # type: ignore[override]
        """
        Method to invoke the concerned model and get bertscore
        :param target_output: Reference text
        :param model_output: Model prediction text
        :returns: BERTScore F1 of model_output measured against target_output
        """
        # Note: the following code is covered by unit tests,
        # but since it gets executed by Ray, Mypy marks it
        # as not covered.
        return self._bertscore.compute(  # pragma: no cover
            predictions=[model_output],
            references=[target_output],
            model_type=self._model_type,
        )["f1"][0]
184
185
class BertscoreHelperModelTypes(str, Enum):
    """Enumerates the model names that may be used to compute BERTScore."""

    MICROSOFT_DEBERTA_MODEL = "microsoft/deberta-xlarge-mnli"
    ROBERTA_MODEL = "roberta-large-mnli"

    @classmethod
    def model_is_allowed(cls, model_name: str) -> bool:
        """
        Tell whether model_name (e.g. 'roberta-large-mnli') is an allowed BERTScore model.

        :param model_name: name to look up; compared against member values, not member names
        :returns: True if some member's value equals model_name, otherwise False
        """
        # Because this is a (str, Enum), need cast to keep mypy happy:
        return any(cast(BertscoreHelperModelTypes, member).value == model_name for member in cls)

    @classmethod
    def model_list(cls) -> List[str]:
        """
        List the values of all allowed BERTScore models, in definition order.
        """
        # Because this is a (str, Enum), need cast to keep mypy happy:
        allowed_names = [cast(BertscoreHelperModelTypes, member).value for member in cls]
        return allowed_names
# Key of the single score produced by ToxigenHelperModel.
TOXIGEN_SCORE_NAME = "toxicity"

# Keys of the individual scores produced by DetoxifyHelperModel.
DETOXIFY_SCORE_TOXICITY = "toxicity"
DETOXIFY_SCORE_SEVERE_TOXICITY = "severe_toxicity"
DETOXIFY_SCORE_OBSCENE = "obscene"
DETOXIFY_SCORE_IDENTITY_ATTACK = "identity_attack"
DETOXIFY_SCORE_INSULT = "insult"
DETOXIFY_SCORE_THREAT = "threat"
DETOXIFY_SCORE_SEXUAL_EXPLICIT = "sexual_explicit"

# Ordered Detoxify score keys; the i-th name labels the i-th model output column.
DETOXIFY_SCORE_NAMES = [
    DETOXIFY_SCORE_TOXICITY,
    DETOXIFY_SCORE_SEVERE_TOXICITY,
    DETOXIFY_SCORE_OBSCENE,
    DETOXIFY_SCORE_IDENTITY_ATTACK,
    DETOXIFY_SCORE_INSULT,
    DETOXIFY_SCORE_THREAT,
    DETOXIFY_SCORE_SEXUAL_EXPLICIT,
]
class BaseHelperModel(ABC):
    """
    Abstract base class for invoking third-party (3P) helper models.

    Helper models are themselves machine learning models that the evaluation
    algorithms call to obtain scores. Concrete subclasses must implement
    get_helper_scores.
    """

    @abstractmethod
    def get_helper_scores(self, text_input: str) -> Any:
        """
        Invoke the helper model.

        :param text_input: model text input
        :returns: model output
        """

Base class for invoking third-party (3P) helper models. Note: these helper models are themselves machine learning models used by the evaluation algorithms.

@abstractmethod
def get_helper_scores(self, text_input: str) -> Any:
    """
    Invoke the helper model.

    :param text_input: model text input
    :returns: model output
    """

Method to invoke helper model

Parameters
  • text_input: model text input

Returns
  • model output
class ToxigenHelperModel(BaseHelperModel):
    """
    Helper model for toxigen model: https://huggingface.co/tomh/toxigen_roberta/tree/main
    """

    TOXIGEN_MODEL_NAME = "tomh/toxigen_roberta"

    def __init__(self):
        """
        Constructor to locally load the helper model for inference.

        Builds a Hugging Face "text-classification" pipeline for TOXIGEN_MODEL_NAME.
        """
        self._model = pipeline("text-classification", model=self.TOXIGEN_MODEL_NAME)

    def __reduce__(self):
        """Serializer method so that instances of this class can be made into shared resources."""
        # Unpickling calls the class with no arguments, so __init__ rebuilds the pipeline
        # instead of the model weights being pickled.
        return self.__class__, ()

    def get_helper_scores(self, text_input: List[str]) -> Dict[str, List[float]]:  # type: ignore[override]
        """
        Method to get scores from ToxigenHelper
        :param text_input: list of text inputs for the model
        :returns: dict with key as score name and value being list of scores for text inputs

        Note: Toxigen scores are for label: LABEL_1
        """
        # truncation=True: without it, inputs longer than the model's maximum sequence
        # length make the pipeline fail. DetoxifyHelperModel truncates its inputs as well.
        predictions = self._model(text_input, truncation=True)
        # Each prediction carries one label and its score. When the winning label is not
        # LABEL_1, its complement is used as the LABEL_1 score (this assumes a two-label model).
        toxicity = [
            prediction["score"] if prediction["label"] == "LABEL_1" else 1.0 - prediction["score"]
            for prediction in predictions
        ]
        return {TOXIGEN_SCORE_NAME: toxicity}

    @staticmethod
    def get_score_names() -> List[str]:
        """
        Util method to return name of scores generated by helper model
        :returns: List of score names
        """
        return [TOXIGEN_SCORE_NAME]
def __init__(self):
    """
    Load the toxigen classifier locally for inference.

    Builds a Hugging Face "text-classification" pipeline for TOXIGEN_MODEL_NAME.
    """
    model_name = self.TOXIGEN_MODEL_NAME
    self._model = pipeline("text-classification", model=model_name)

Constructor to locally load the helper model for inference.

# Hugging Face model id passed to pipeline("text-classification", ...) by ToxigenHelperModel.__init__.
TOXIGEN_MODEL_NAME = 'tomh/toxigen_roberta'
def get_helper_scores(self, text_input: List[str]) -> Dict[str, List[float]]:  # type: ignore[override]
    """
    Method to get scores from ToxigenHelper
    :param text_input: list of text inputs for the model
    :returns: dict with key as score name and value being list of scores for text inputs

    Note: Toxigen scores are for label: LABEL_1
    """
    # truncation=True: without it, inputs longer than the model's maximum sequence
    # length make the pipeline fail. DetoxifyHelperModel truncates its inputs as well.
    predictions = self._model(text_input, truncation=True)
    # Each prediction carries one label and its score. When the winning label is not
    # LABEL_1, its complement is used as the LABEL_1 score (this assumes a two-label model).
    toxicity = [
        prediction["score"] if prediction["label"] == "LABEL_1" else 1.0 - prediction["score"]
        for prediction in predictions
    ]
    return {TOXIGEN_SCORE_NAME: toxicity}

Method to get scores from ToxigenHelper

Parameters
  • text_input: list of text inputs for the model

Returns
  • dict whose key is the score name and whose value is the list of scores for the text inputs

Note: Toxigen scores are for label: LABEL_1

@staticmethod
def get_score_names() -> List[str]:
    """
    Return the names of the scores produced by ToxigenHelperModel.

    :returns: a new single-element list holding TOXIGEN_SCORE_NAME
    """
    names = [TOXIGEN_SCORE_NAME]
    return names

Utility method that returns the names of the scores generated by the helper model.

Returns
  • List of score names

class DetoxifyHelperModel(BaseHelperModel):
    """
    Helper model for Detoxify: https://github.com/unitaryai/detoxify

    Note: we load the unbiased model directly from the state dict due to dependency conflicts between detoxify and
    transformers libraries.

    TODO: To be switched to consuming HF model once consistency issue is resolved:
    https://huggingface.co/unitary/unbiased-toxic-roberta. This will allow removing the detoxify PyPI package as a
    dependency and updating the transformers version we consume.
    """

    UNBIASED_MODEL_URL = (
        "https://github.com/unitaryai/detoxify/releases/download/v0.3-alpha/toxic_debiased-c7548aa0.ckpt"
    )

    def __init__(self):
        """
        Constructor to locally load the helper model for inference.

        Downloads the checkpoint at UNBIASED_MODEL_URL via torch.hub and rebuilds the model and tokenizer from the
        architecture arguments stored inside it. The model is placed on the CPU.
        """
        state_dict = torch.hub.load_state_dict_from_url(self.UNBIASED_MODEL_URL, map_location="cpu")
        # The checkpoint records the transformers model/tokenizer class names, the base model type and the number
        # of output classes.
        config = state_dict["config"]["arch"]["args"]
        self._model = (
            getattr(transformers, config["model_name"])
            .from_pretrained(
                pretrained_model_name_or_path=None,
                config=AutoConfig.from_pretrained(config["model_type"], num_labels=config["num_classes"]),
                state_dict=state_dict["state_dict"],
                local_files_only=False,
            )
            .to("cpu")
        )
        self._tokenizer = getattr(transformers, config["tokenizer_name"]).from_pretrained(config["model_type"])

    def __reduce__(self):
        """Serializer method so that instances of this class can be made into shared resources."""
        # Unpickling calls the class with no arguments, so __init__ reloads the checkpoint instead of the model
        # weights being pickled.
        return self.__class__, ()

    def get_helper_scores(self, text_input: List[str]) -> Dict[str, List[float]]:  # type: ignore[override]
        """
        Method to get scores from DetoxifyHelper
        :param text_input: list of text inputs for the model
        :returns: dict with keys as score name and value being list of scores for text inputs
                  (each list is empty when text_input is empty)
        """
        score_names = DetoxifyHelperModel.get_score_names()
        if not text_input:
            # Nothing to score: skip tokenizing and running the model on an empty batch.
            return {score_name: [] for score_name in score_names}
        inputs = self._tokenizer(text_input, return_tensors="pt", truncation=True, padding=True).to(self._model.device)
        # Inference only: no_grad avoids building and retaining an autograd graph for the forward pass.
        with torch.no_grad():
            logits = self._model(**inputs)[0]
        # sigmoid maps every label's logit to [0, 1] independently of the other labels.
        scores = torch.sigmoid(logits).cpu().numpy()
        # Column i of the (batch, labels) score matrix belongs to the i-th score name.
        return {score_name: scores[:, i].tolist() for i, score_name in enumerate(score_names)}

    @staticmethod
    def get_score_names() -> List[str]:
        """
        Util method to return name of scores generated by helper model
        :returns: List of score names, ordered like the model's output columns. A fresh copy is returned so callers
                  cannot mutate the module-level DETOXIFY_SCORE_NAMES.
        """
        return list(DETOXIFY_SCORE_NAMES)

Helper model for Detoxify: https://github.com/unitaryai/detoxify

Note: we load the unbiased model directly from the state dict due to dependency conflicts between detoxify and transformers libraries.

TODO: To be switched to consuming the HF model once the consistency issue is resolved: https://huggingface.co/unitary/unbiased-toxic-roberta. This will allow removing the detoxify PyPI package as a dependency and updating the transformers version we consume.

def __init__(self):
    """
    Load the Detoxify "unbiased" model and its tokenizer locally for inference.

    The checkpoint is downloaded via torch.hub from UNBIASED_MODEL_URL. The transformers model and tokenizer
    classes are looked up by the names stored in the checkpoint, and the model is placed on the CPU.
    """
    checkpoint = torch.hub.load_state_dict_from_url(self.UNBIASED_MODEL_URL, map_location="cpu")
    arch_args = checkpoint["config"]["arch"]["args"]

    model_cls = getattr(transformers, arch_args["model_name"])
    model_config = AutoConfig.from_pretrained(arch_args["model_type"], num_labels=arch_args["num_classes"])
    loaded_model = model_cls.from_pretrained(
        pretrained_model_name_or_path=None,
        config=model_config,
        state_dict=checkpoint["state_dict"],
        local_files_only=False,
    )
    self._model = loaded_model.to("cpu")

    tokenizer_cls = getattr(transformers, arch_args["tokenizer_name"])
    self._tokenizer = tokenizer_cls.from_pretrained(arch_args["model_type"])

Constructor to locally load the helper model for inference.

# Download URL of the Detoxify checkpoint that DetoxifyHelperModel.__init__ loads via torch.hub.load_state_dict_from_url.
UNBIASED_MODEL_URL = 'https://github.com/unitaryai/detoxify/releases/download/v0.3-alpha/toxic_debiased-c7548aa0.ckpt'
def get_helper_scores(self, text_input: List[str]) -> Dict[str, List[float]]:  # type: ignore[override]
    """
    Method to get scores from DetoxifyHelper
    :param text_input: list of text inputs for the model
    :returns: dict with keys as score name and value being list of scores for text inputs
              (each list is empty when text_input is empty)
    """
    score_names = DetoxifyHelperModel.get_score_names()
    if not text_input:
        # Nothing to score: skip tokenizing and running the model on an empty batch.
        return {score_name: [] for score_name in score_names}
    inputs = self._tokenizer(text_input, return_tensors="pt", truncation=True, padding=True).to(self._model.device)
    # Inference only: no_grad avoids building and retaining an autograd graph for the forward pass.
    with torch.no_grad():
        logits = self._model(**inputs)[0]
    # sigmoid maps every label's logit to [0, 1] independently of the other labels.
    scores = torch.sigmoid(logits).cpu().numpy()
    # Column i of the (batch, labels) score matrix belongs to the i-th score name.
    return {score_name: scores[:, i].tolist() for i, score_name in enumerate(score_names)}

Method to get scores from DetoxifyHelper

Parameters
  • text_input: list of text inputs for the model

Returns
  • dict whose keys are the score names and whose values are the lists of scores for the text inputs
@staticmethod
def get_score_names() -> List[str]:
    """
    Util method to return name of scores generated by helper model
    :returns: List of score names, ordered like the model's output columns. A fresh copy is returned so callers
              cannot mutate the module-level DETOXIFY_SCORE_NAMES.
    """
    return list(DETOXIFY_SCORE_NAMES)

Utility method that returns the names of the scores generated by the helper model.

Returns
  • List of score names

class BertscoreHelperModel(BaseHelperModel):
    """
    BERTScore is a similarity-based metric that compares the embeddings of the prediction and target sentences
    under a (learned) model, typically from the BERT family.
    Because semantically similar sentences are (typically) embedded similarly, this score can be more tolerant of
    rephrasing than ROUGE and METEOR.
    See https://huggingface.co/spaces/evaluate-metric/bertscore
    Note: we specify that this Ray actor requires num_cpus=1 in order to limit the number of concurrently
    running tasks or actors to avoid out of memory issues.
    See https://docs.ray.io/en/latest/ray-core/patterns/limit-running-tasks.html#core-patterns-limit-running-tasks
    for a detailed explanation.
    """

    def __init__(self, model_type: str):  # pragma: no cover
        """
        Default constructor
        :param model_type: Model type to be used for bertscore
        """
        self._bertscore = hf_evaluate.load("bertscore")
        self._model_type = model_type

    def __reduce__(self):
        """Serializer method so that instances of this class can be made into shared resources."""
        # Only the model type is pickled; __init__ reloads the metric when the object is unpickled.
        return self.__class__, (self._model_type,)

    def get_helper_scores(self, target_output: str, model_output: str) -> float:  # type: ignore[override]
        """
        Method to invoke the concerned model and get bertscore
        :param target_output: Reference text
        :param model_output: Model prediction text
        :returns: BERTScore F1 of model_output measured against target_output
        """
        # Note: the following code is covered by unit tests,
        # but since it gets executed by Ray, Mypy marks it
        # as not covered.
        return self._bertscore.compute(  # pragma: no cover
            predictions=[model_output],
            references=[target_output],
            model_type=self._model_type,
        )["f1"][0]

BERTScore is a similarity-based metric that compares the embeddings of the prediction and target sentences under a (learned) model, typically from the BERT family. Because semantically similar sentences are (typically) embedded similarly, this score can be more tolerant of rephrasing than ROUGE and METEOR. See https://huggingface.co/spaces/evaluate-metric/bertscore. Note: we specify that this Ray actor requires num_cpus=1 in order to limit the number of concurrently running tasks or actors and avoid out-of-memory issues. See https://docs.ray.io/en/latest/ray-core/patterns/limit-running-tasks.html#core-patterns-limit-running-tasks for a detailed explanation.

def __init__(self, model_type: str):  # pragma: no cover
    """
    Load the Hugging Face `evaluate` BERTScore metric and remember which model it should use.

    :param model_type: model type forwarded to the metric's compute call
    """
    metric = hf_evaluate.load("bertscore")
    self._bertscore = metric
    self._model_type = model_type

Default constructor

Parameters
  • model_type: Model type to be used for bertscore
def get_helper_scores(self, target_output: str, model_output: str) -> float:  # type: ignore[override]
    """
    Method to invoke the concerned model and get bertscore
    :param target_output: Reference text
    :param model_output: Model prediction text
    :returns: BERTScore F1 of model_output measured against target_output
    """
    # Note: the following code is covered by unit tests,
    # but since it gets executed by Ray, Mypy marks it
    # as not covered.
    return self._bertscore.compute(  # pragma: no cover
        predictions=[model_output],
        references=[target_output],
        model_type=self._model_type,
    )["f1"][0]

Method to invoke the concerned model and get bertscore

Parameters
  • target_output: Reference text
  • model_output: Model prediction text
class BertscoreHelperModelTypes(str, Enum):
    """Enumerates the model names that may be used to compute BERTScore."""

    MICROSOFT_DEBERTA_MODEL = "microsoft/deberta-xlarge-mnli"
    ROBERTA_MODEL = "roberta-large-mnli"

    @classmethod
    def model_is_allowed(cls, model_name: str) -> bool:
        """
        Tell whether model_name (e.g. 'roberta-large-mnli') is an allowed BERTScore model.

        :param model_name: name to look up; compared against member values, not member names
        :returns: True if some member's value equals model_name, otherwise False
        """
        # Because this is a (str, Enum), need cast to keep mypy happy:
        return any(cast(BertscoreHelperModelTypes, member).value == model_name for member in cls)

    @classmethod
    def model_list(cls) -> List[str]:
        """
        List the values of all allowed BERTScore models, in definition order.
        """
        # Because this is a (str, Enum), need cast to keep mypy happy:
        allowed_names = [cast(BertscoreHelperModelTypes, member).value for member in cls]
        return allowed_names

This class holds the names of all the allowed models for computing the BERTScore.

MICROSOFT_DEBERTA_MODEL = <BertscoreHelperModelTypes.MICROSOFT_DEBERTA_MODEL: 'microsoft/deberta-xlarge-mnli'>
ROBERTA_MODEL = <BertscoreHelperModelTypes.ROBERTA_MODEL: 'roberta-large-mnli'>
@classmethod
def model_is_allowed(cls, model_name: str) -> bool:
    """
    Tell whether model_name (e.g. 'roberta-large-mnli') is an allowed BERTScore model.

    :param model_name: name to look up; compared against member values, not member names
    :returns: True if some member's value equals model_name, otherwise False
    """
    # Because this is a (str, Enum), need cast to keep mypy happy:
    return any(cast(BertscoreHelperModelTypes, member).value == model_name for member in cls)

Given a model name like 'roberta-large-mnli', check if this is an allowed model for computing BERTScore.

@classmethod
def model_list(cls) -> List[str]:
    """
    List the values of all allowed BERTScore models, in definition order.
    """
    # Because this is a (str, Enum), need cast to keep mypy happy:
    allowed_names = [cast(BertscoreHelperModelTypes, member).value for member in cls]
    return allowed_names

Return a list of all the allowed models for computing BERTScore.