fmeval.eval_algorithms.helper_models.helper_model
import evaluate as hf_evaluate
import torch
import transformers

from enum import Enum
from abc import ABC, abstractmethod
from typing import Any, cast, Dict, List
from transformers import pipeline, AutoConfig

TOXIGEN_SCORE_NAME = "toxicity"

DETOXIFY_SCORE_TOXICITY = "toxicity"
DETOXIFY_SCORE_SEVERE_TOXICITY = "severe_toxicity"
DETOXIFY_SCORE_OBSCENE = "obscene"
DETOXIFY_SCORE_IDENTITY_ATTACK = "identity_attack"
DETOXIFY_SCORE_INSULT = "insult"
DETOXIFY_SCORE_THREAT = "threat"
DETOXIFY_SCORE_SEXUAL_EXPLICIT = "sexual_explicit"
# Order matters: DetoxifyHelperModel maps the i-th model output column to the i-th name.
DETOXIFY_SCORE_NAMES = [
    DETOXIFY_SCORE_TOXICITY,
    DETOXIFY_SCORE_SEVERE_TOXICITY,
    DETOXIFY_SCORE_OBSCENE,
    DETOXIFY_SCORE_IDENTITY_ATTACK,
    DETOXIFY_SCORE_INSULT,
    DETOXIFY_SCORE_THREAT,
    DETOXIFY_SCORE_SEXUAL_EXPLICIT,
]


class BaseHelperModel(ABC):
    """
    Base class for third-party (3P) helper model invokers. Note: these helper models are themselves
    machine learning models used by the evaluation algorithms.
    """

    @abstractmethod
    def get_helper_scores(self, text_input: str) -> Any:
        """
        Method to invoke helper model
        :param text_input: model text input
        :returns: model output
        """


class ToxigenHelperModel(BaseHelperModel):
    """
    Helper model for toxigen model: https://huggingface.co/tomh/toxigen_roberta/tree/main
    """

    TOXIGEN_MODEL_NAME = "tomh/toxigen_roberta"

    def __init__(self):
        """
        Constructor to locally load the helper model for inference.
        """
        self._model = pipeline("text-classification", model=self.TOXIGEN_MODEL_NAME)

    def __reduce__(self):
        """Serializer method so that instances of this class can be made into shared resources."""
        # Rebuild from the class alone: the pipeline is re-created by __init__ instead of being pickled.
        return self.__class__, ()

    def get_helper_scores(self, text_input: List[str]) -> Dict[str, List[float]]:  # type: ignore[override]
        """
        Method to get scores from ToxigenHelper
        :param text_input: list of text inputs for the model
        :returns: dict with key as score name and value being list of scores for text inputs

        Note: Toxigen scores are for label: LABEL_1. When the pipeline reports another label, its score
        is converted to the LABEL_1 score as ``1.0 - score``.
        """
        # Truncate over-long inputs so they are scored on their prefix instead of failing inside the
        # model; this mirrors the truncation DetoxifyHelperModel applies in its tokenizer call.
        # NOTE(review): 512 is the usual RoBERTa token limit; confirm against the checkpoint's config.
        inference_output = self._model(text_input, truncation=True, max_length=512)
        return {
            TOXIGEN_SCORE_NAME: [
                x["score"] if x["label"] == "LABEL_1" else 1.0 - x["score"] for x in inference_output
            ]
        }

    @staticmethod
    def get_score_names() -> List[str]:
        """
        Util method to return name of scores generated by helper model
        :returns: List of score names
        """
        return [TOXIGEN_SCORE_NAME]


class DetoxifyHelperModel(BaseHelperModel):
    """
    Helper model for Detoxify: https://github.com/unitaryai/detoxify

    Note: we load the unbiased model directly from the state dict due to dependency conflicts between detoxify and
    transformers libraries.

    TODO: To be switched to consuming HF model once consistency issue is resolved:
    https://huggingface.co/unitary/unbiased-toxic-roberta. This will allow removing detoxify PyPI as a dependency
    and updating the transformers version we are consuming.
    """

    UNBIASED_MODEL_URL = (
        "https://github.com/unitaryai/detoxify/releases/download/v0.3-alpha/toxic_debiased-c7548aa0.ckpt"
    )

    def __init__(self):
        """
        Constructor to locally load the helper model for inference.
        """
        # The checkpoint bundles both the weights ("state_dict") and the architecture arguments ("config").
        state_dict = torch.hub.load_state_dict_from_url(self.UNBIASED_MODEL_URL, map_location="cpu")
        config = state_dict["config"]["arch"]["args"]
        self._model = (
            getattr(transformers, config["model_name"])
            .from_pretrained(
                pretrained_model_name_or_path=None,
                config=AutoConfig.from_pretrained(config["model_type"], num_labels=config["num_classes"]),
                state_dict=state_dict["state_dict"],
                local_files_only=False,
            )
            .to("cpu")
        )
        self._tokenizer = getattr(transformers, config["tokenizer_name"]).from_pretrained(config["model_type"])

    def __reduce__(self):
        """Serializer method so that instances of this class can be made into shared resources."""
        # Rebuild from the class alone: model and tokenizer are re-loaded by __init__ instead of being pickled.
        return self.__class__, ()

    def get_helper_scores(self, text_input: List[str]) -> Dict[str, List[float]]:  # type: ignore[override]
        """
        Method to get scores from DetoxifyHelper
        :param text_input: list of text inputs for the model
        :returns: dict with keys as score name and value being list of scores for text inputs
        """
        inputs = self._tokenizer(text_input, return_tensors="pt", truncation=True, padding=True).to(self._model.device)
        # Inference only: skip autograd bookkeeping so no graph is kept alive for the batch.
        with torch.no_grad():
            logits = self._model(**inputs)[0]
        scores = torch.sigmoid(logits).cpu().numpy()
        # Column i of the score matrix belongs to the i-th score name.
        return {
            score_name: scores[:, i].tolist()
            for i, score_name in enumerate(DetoxifyHelperModel.get_score_names())
        }

    @staticmethod
    def get_score_names() -> List[str]:
        """
        Util method to return name of scores generated by helper model
        :returns: List of score names
        """
        return DETOXIFY_SCORE_NAMES


class BertscoreHelperModel(BaseHelperModel):
    """
    BERTscore is a similarity-based metric that compares the embedding of the prediction and target sentences
    under a (learned) model, typically, from the BERT family.
    This score may lead to increased flexibility compared to rouge and METEOR in terms of rephrasing since
    semantically similar sentences are (typically) embedded similarly.
    https://huggingface.co/spaces/evaluate-metric/bertscore
    Note: we specify that this Ray actor requires num_cpus=1 in order to limit the number of concurrently
    running tasks or actors to avoid out of memory issues.
    See https://docs.ray.io/en/latest/ray-core/patterns/limit-running-tasks.html#core-patterns-limit-running-tasks
    for a detailed explanation.
    """

    def __init__(self, model_type: str):  # pragma: no cover
        """
        Default constructor
        :param model_type: Model type to be used for bertscore
        """
        self._bertscore = hf_evaluate.load("bertscore")
        self._model_type = model_type

    def __reduce__(self):
        """Serializer method so that instances of this class can be made into shared resources."""
        # Only the model type is carried over; the metric object is re-loaded by __init__.
        return self.__class__, (self._model_type,)

    def get_helper_scores(self, target_output: str, model_output: str) -> float:  # type: ignore[override]
        """
        Method to invoke the concerned model and get bertscore
        :param target_output: Reference text
        :param model_output: Model prediction text
        :returns: the "f1" BERTScore of the single prediction/reference pair
        """
        # Note: the following code is covered by unit tests, but since it gets executed by Ray
        # it is reported as not covered, hence the pragma.
        return self._bertscore.compute(  # pragma: no cover
            predictions=[model_output],
            references=[target_output],
            model_type=self._model_type,
        )["f1"][0]


class BertscoreHelperModelTypes(str, Enum):
    """This class holds the names of all the allowed models for computing the BERTScore."""

    MICROSOFT_DEBERTA_MODEL = "microsoft/deberta-xlarge-mnli"
    ROBERTA_MODEL = "roberta-large-mnli"

    @classmethod
    def model_is_allowed(cls, model_name: str) -> bool:
        """
        Given a model name like 'roberta-large-mnli', check if this is an allowed model for computing BERTScore.
        :param model_name: name of the model to check
        :returns: True if some member's value equals model_name, False otherwise
        """
        # Because this is a (str, Enum), need cast to keep mypy happy:
        return any(cast(BertscoreHelperModelTypes, elem).value == model_name for elem in cls)

    @classmethod
    def model_list(cls) -> List[str]:
        """
        Return a list of all the allowed models for computing BERTScore.
        :returns: member values, in definition order
        """
        # Because this is a (str, Enum), need cast to keep mypy happy:
        return [cast(BertscoreHelperModelTypes, elem).value for elem in cls]
class BaseHelperModel(ABC):
    """
    Abstract interface for third-party (3P) helper model invokers.

    Helper models are machine learning models that the evaluation algorithms call into.
    Subclasses must implement ``get_helper_scores``.
    """

    @abstractmethod
    def get_helper_scores(self, text_input: str) -> Any:
        """
        Invoke the helper model.

        :param text_input: text to feed to the model
        :returns: whatever the concrete helper model produces
        """
Base class for third-party (3P) helper model invokers. Note: these helper models are themselves machine learning models used by the evaluation algorithms.
@abstractmethod
def get_helper_scores(self, text_input: str) -> Any:
    """
    Invoke the helper model.

    :param text_input: text to feed to the model
    :returns: whatever the concrete helper model produces
    """
Method to invoke helper model
Parameters
- text_input: model text input

Returns: model output
class ToxigenHelperModel(BaseHelperModel):
    """
    Helper model for toxigen model: https://huggingface.co/tomh/toxigen_roberta/tree/main
    """

    TOXIGEN_MODEL_NAME = "tomh/toxigen_roberta"

    def __init__(self):
        """
        Constructor to locally load the helper model for inference.
        """
        self._model = pipeline("text-classification", model=self.TOXIGEN_MODEL_NAME)

    def __reduce__(self):
        """Serializer method so that instances of this class can be made into shared resources."""
        # Rebuild from the class alone: the pipeline is re-created by __init__ instead of being pickled.
        return self.__class__, ()

    def get_helper_scores(self, text_input: List[str]) -> Dict[str, List[float]]:  # type: ignore[override]
        """
        Method to get scores from ToxigenHelper
        :param text_input: list of text inputs for the model
        :returns: dict with key as score name and value being list of scores for text inputs

        Note: Toxigen scores are for label: LABEL_1. When the pipeline reports another label, its score
        is converted to the LABEL_1 score as ``1.0 - score``.
        """
        # Truncate over-long inputs so they are scored on their prefix instead of failing inside the
        # model; this mirrors the truncation DetoxifyHelperModel applies in its tokenizer call.
        # NOTE(review): 512 is the usual RoBERTa token limit; confirm against the checkpoint's config.
        inference_output = self._model(text_input, truncation=True, max_length=512)
        return {
            TOXIGEN_SCORE_NAME: [
                x["score"] if x["label"] == "LABEL_1" else 1.0 - x["score"] for x in inference_output
            ]
        }

    @staticmethod
    def get_score_names() -> List[str]:
        """
        Util method to return name of scores generated by helper model
        :returns: List of score names
        """
        return [TOXIGEN_SCORE_NAME]
Helper model for toxigen model: https://huggingface.co/tomh/toxigen_roberta/tree/main
def __init__(self):
    """
    Build the text-classification pipeline for the Toxigen checkpoint and keep it on the instance.
    """
    classifier = pipeline("text-classification", model=self.TOXIGEN_MODEL_NAME)
    self._model = classifier
Constructor to locally load the helper model for inference.
def get_helper_scores(self, text_input: List[str]) -> Dict[str, List[float]]:  # type: ignore[override]
    """
    Method to get scores from ToxigenHelper
    :param text_input: list of text inputs for the model
    :returns: dict with key as score name and value being list of scores for text inputs

    Note: Toxigen scores are for label: LABEL_1. When the pipeline reports another label, its score
    is converted to the LABEL_1 score as ``1.0 - score``.
    """
    # Truncate over-long inputs so they are scored on their prefix instead of failing inside the
    # model; this mirrors the truncation DetoxifyHelperModel applies in its tokenizer call.
    # NOTE(review): 512 is the usual RoBERTa token limit; confirm against the checkpoint's config.
    inference_output = self._model(text_input, truncation=True, max_length=512)
    return {
        TOXIGEN_SCORE_NAME: [
            x["score"] if x["label"] == "LABEL_1" else 1.0 - x["score"] for x in inference_output
        ]
    }
Method to get scores from ToxigenHelper
Parameters
- text_input: list of text inputs for the model

Returns: dict with key as score name and value being list of scores for text inputs
Note: Toxigen scores are for label: LABEL_1
@staticmethod
def get_score_names() -> List[str]:
    """
    List the names of the scores this helper model produces.

    :returns: a new single-element list holding the Toxigen score name
    """
    score_names = [TOXIGEN_SCORE_NAME]
    return score_names
Util method to return the names of the scores generated by the helper model.

Returns: List of score names
class DetoxifyHelperModel(BaseHelperModel):
    """
    Helper model for Detoxify: https://github.com/unitaryai/detoxify

    Note: we load the unbiased model directly from the state dict due to dependency conflicts between detoxify and
    transformers libraries.

    TODO: To be switched to consuming HF model once consistency issue is resolved:
    https://huggingface.co/unitary/unbiased-toxic-roberta. This will allow removing detoxify PyPI as a dependency
    and updating the transformers version we are consuming.
    """

    UNBIASED_MODEL_URL = (
        "https://github.com/unitaryai/detoxify/releases/download/v0.3-alpha/toxic_debiased-c7548aa0.ckpt"
    )

    def __init__(self):
        """
        Constructor to locally load the helper model for inference.
        """
        # The checkpoint bundles both the weights ("state_dict") and the architecture arguments ("config").
        state_dict = torch.hub.load_state_dict_from_url(self.UNBIASED_MODEL_URL, map_location="cpu")
        config = state_dict["config"]["arch"]["args"]
        self._model = (
            getattr(transformers, config["model_name"])
            .from_pretrained(
                pretrained_model_name_or_path=None,
                config=AutoConfig.from_pretrained(config["model_type"], num_labels=config["num_classes"]),
                state_dict=state_dict["state_dict"],
                local_files_only=False,
            )
            .to("cpu")
        )
        self._tokenizer = getattr(transformers, config["tokenizer_name"]).from_pretrained(config["model_type"])

    def __reduce__(self):
        """Serializer method so that instances of this class can be made into shared resources."""
        # Rebuild from the class alone: model and tokenizer are re-loaded by __init__ instead of being pickled.
        return self.__class__, ()

    def get_helper_scores(self, text_input: List[str]) -> Dict[str, List[float]]:  # type: ignore[override]
        """
        Method to get scores from DetoxifyHelper
        :param text_input: list of text inputs for the model
        :returns: dict with keys as score name and value being list of scores for text inputs
        """
        inputs = self._tokenizer(text_input, return_tensors="pt", truncation=True, padding=True).to(self._model.device)
        # Inference only: skip autograd bookkeeping so no graph is kept alive for the batch.
        with torch.no_grad():
            logits = self._model(**inputs)[0]
        scores = torch.sigmoid(logits).cpu().numpy()
        # Column i of the score matrix belongs to the i-th score name.
        return {
            score_name: scores[:, i].tolist()
            for i, score_name in enumerate(DetoxifyHelperModel.get_score_names())
        }

    @staticmethod
    def get_score_names() -> List[str]:
        """
        Util method to return name of scores generated by helper model
        :returns: List of score names
        """
        return DETOXIFY_SCORE_NAMES
Helper model for Detoxify: https://github.com/unitaryai/detoxify
Note: we load the unbiased model directly from the state dict due to dependency conflicts between detoxify and transformers libraries.
TODO: To be switched to consuming the HF model once the consistency issue is resolved: https://huggingface.co/unitary/unbiased-toxic-roberta. This will allow removing the detoxify PyPI package as a dependency and updating the transformers version we are consuming.
def __init__(self):
    """
    Download the Detoxify checkpoint and build the model and tokenizer it describes, on CPU.
    """
    # The checkpoint bundles both the weights ("state_dict") and the architecture arguments ("config").
    checkpoint = torch.hub.load_state_dict_from_url(self.UNBIASED_MODEL_URL, map_location="cpu")
    arch_args = checkpoint["config"]["arch"]["args"]

    # Model and tokenizer classes are looked up on the transformers module by the names in the checkpoint.
    model_cls = getattr(transformers, arch_args["model_name"])
    model_config = AutoConfig.from_pretrained(arch_args["model_type"], num_labels=arch_args["num_classes"])
    loaded_model = model_cls.from_pretrained(
        pretrained_model_name_or_path=None,
        config=model_config,
        state_dict=checkpoint["state_dict"],
        local_files_only=False,
    )
    self._model = loaded_model.to("cpu")

    tokenizer_cls = getattr(transformers, arch_args["tokenizer_name"])
    self._tokenizer = tokenizer_cls.from_pretrained(arch_args["model_type"])
Constructor to locally load the helper model for inference.
def get_helper_scores(self, text_input: List[str]) -> Dict[str, List[float]]:  # type: ignore[override]
    """
    Method to get scores from DetoxifyHelper
    :param text_input: list of text inputs for the model
    :returns: dict with keys as score name and value being list of scores for text inputs
    """
    inputs = self._tokenizer(text_input, return_tensors="pt", truncation=True, padding=True).to(self._model.device)
    # Inference only: skip autograd bookkeeping so no graph is kept alive for the batch.
    with torch.no_grad():
        logits = self._model(**inputs)[0]
    scores = torch.sigmoid(logits).cpu().numpy()
    # Column i of the score matrix belongs to the i-th score name.
    return {
        score_name: scores[:, i].tolist()
        for i, score_name in enumerate(DetoxifyHelperModel.get_score_names())
    }
Method to get scores from DetoxifyHelper
Parameters
- text_input: list of text inputs for the model

Returns: dict with keys as score name and value being list of scores for text inputs
@staticmethod
def get_score_names() -> List[str]:
    """
    List the names of the scores this helper model produces.

    :returns: the module-level DETOXIFY_SCORE_NAMES list (the same list object, not a copy)
    """
    return DETOXIFY_SCORE_NAMES
Util method to return the names of the scores generated by the helper model.

Returns: List of score names
class BertscoreHelperModel(BaseHelperModel):
    """
    BERTscore is a similarity-based metric that compares the embedding of the prediction and target sentences
    under a (learned) model, typically, from the BERT family.
    This score may lead to increased flexibility compared to rouge and METEOR in terms of rephrasing since
    semantically similar sentences are (typically) embedded similarly.
    https://huggingface.co/spaces/evaluate-metric/bertscore
    Note: we specify that this Ray actor requires num_cpus=1 in order to limit the number of concurrently
    running tasks or actors to avoid out of memory issues.
    See https://docs.ray.io/en/latest/ray-core/patterns/limit-running-tasks.html#core-patterns-limit-running-tasks
    for a detailed explanation.
    """

    def __init__(self, model_type: str):  # pragma: no cover
        """
        Load the "bertscore" metric and remember which model it should be computed with.
        :param model_type: Model type to be used for bertscore
        """
        self._model_type = model_type
        self._bertscore = hf_evaluate.load("bertscore")

    def __reduce__(self):
        """Serializer method so that instances of this class can be made into shared resources."""
        # Only the model type is carried over; the metric object is re-loaded by __init__.
        constructor_args = (self._model_type,)
        return self.__class__, constructor_args

    def get_helper_scores(self, target_output: str, model_output: str) -> float:  # type: ignore[override]
        """
        Compute the BERTScore of one prediction against one reference.
        :param target_output: Reference text
        :param model_output: Model prediction text
        :returns: the "f1" entry of the metric result for this single pair
        """
        # Covered by unit tests, but it runs inside Ray where it is reported as not covered,
        # hence the pragmas.
        metric_result = self._bertscore.compute(  # pragma: no cover
            predictions=[model_output],
            references=[target_output],
            model_type=self._model_type,
        )
        return metric_result["f1"][0]  # pragma: no cover
BERTscore is a similarity-based metric that compares the embedding of the prediction and target sentences under a (learned) model, typically, from the BERT family. This score may lead to increased flexibility compared to rouge and METEOR in terms of rephrasing since semantically similar sentences are (typically) embedded similarly. https://huggingface.co/spaces/evaluate-metric/bertscore Note: we specify that this Ray actor requires num_cpus=1 in order to limit the number of concurrently running tasks or actors to avoid out of memory issues. See https://docs.ray.io/en/latest/ray-core/patterns/limit-running-tasks.html#core-patterns-limit-running-tasks for a detailed explanation.
def __init__(self, model_type: str):  # pragma: no cover
    """
    Load the "bertscore" metric and remember which model it should be computed with.
    :param model_type: Model type to be used for bertscore
    """
    self._model_type = model_type
    self._bertscore = hf_evaluate.load("bertscore")
Default constructor
Parameters
- model_type: Model type to be used for bertscore
171 def get_helper_scores(self, target_output: str, model_output: str) -> float: # type: ignore[override] 172 """ 173 Method to invoke the concerned model and get bertscore 174 :param target_output: Reference text 175 :model_output: Model prediction text 176 """ 177 # Note: the following code is covered by unit tests, 178 # but since it gets executed by Ray, Mypy marks it 179 # as not covered. 180 return self._bertscore.compute( # pragma: no cover 181 predictions=[model_output], 182 references=[target_output], 183 model_type=self._model_type, 184 )["f1"][0]
Method to invoke the concerned model and get bertscore
Parameters
- target_output: Reference text
- model_output: Model prediction text
class BertscoreHelperModelTypes(str, Enum):
    """Enumeration of the model names that are allowed for computing the BERTScore."""

    MICROSOFT_DEBERTA_MODEL = "microsoft/deberta-xlarge-mnli"
    ROBERTA_MODEL = "roberta-large-mnli"

    @classmethod
    def model_is_allowed(cls, model_name: str) -> bool:
        """
        Check whether a model name such as 'roberta-large-mnli' is one of the allowed BERTScore models.
        :param model_name: name of the model to check
        :returns: True if some member's value equals model_name, False otherwise
        """
        # The cast is only there for mypy, because this is a (str, Enum).
        return any(cast(BertscoreHelperModelTypes, member).value == model_name for member in cls)

    @classmethod
    def model_list(cls) -> List[str]:
        """
        List every model name that is allowed for computing BERTScore.
        :returns: member values, in definition order
        """
        # The cast is only there for mypy, because this is a (str, Enum).
        allowed_models = [cast(BertscoreHelperModelTypes, member).value for member in cls]
        return allowed_models
This class holds the names of all the allowed models for computing the BERTScore.
193 @classmethod 194 def model_is_allowed(cls, model_name: str) -> bool: 195 """ 196 Given a model name like 'roberta-large-mnli', check if this is an allowed model for computing BERTScore. 197 """ 198 for elem in iter(cls): 199 # Because this is a (str, Enum), need cast to keep mypy happy: 200 if cast(BertscoreHelperModelTypes, elem).value == model_name: 201 return True 202 return False
Given a model name like 'roberta-large-mnli', check if this is an allowed model for computing BERTScore.
204 @classmethod 205 def model_list(cls) -> List[str]: 206 """ 207 Return a list of all the allowed models for computing BERTScore. 208 """ 209 # Because this is a (str, Enum), need cast to keep mypy happy: 210 return [cast(BertscoreHelperModelTypes, elem).value for elem in iter(cls)]
Return a list of all the allowed models for computing BERTScore.