fmeval.eval_algorithms.semantic_robustness_utils
from dataclasses import dataclass
from typing import Tuple

from fmeval.constants import BUTTER_FINGER, RANDOM_UPPER_CASE, WHITESPACE_ADD_REMOVE, DatasetColumns
from fmeval.eval_algorithms.eval_algorithm import EvalAlgorithmConfig
from fmeval.model_runners.model_runner import ModelRunner
from fmeval.transforms.common import GeneratePrompt, GetModelOutputs
from fmeval.transforms.semantic_perturbations import (
    SemanticPerturbation,
    ButterFinger,
    RandomUppercase,
    AddRemoveWhitespace,
)
from fmeval.transforms.util import create_output_key
from fmeval.util import require

# Maps each supported perturbation type name to the transform class that implements it.
# SemanticRobustnessConfig validates `perturbation_type` against the keys of this mapping.
SEMANTIC_PERTURBATIONS = {
    BUTTER_FINGER: ButterFinger,
    RANDOM_UPPER_CASE: RandomUppercase,
    WHITESPACE_ADD_REMOVE: AddRemoveWhitespace,
}


@dataclass(frozen=True)
class SemanticRobustnessConfig(EvalAlgorithmConfig):
    """Configures the semantic robustness evaluation algorithms.

    :param perturbation_type: Perturbation type for generating perturbed inputs.
        Either BUTTER_FINGER, RANDOM_UPPER_CASE, or WHITESPACE_ADD_REMOVE.
    :param num_perturbations: Number of perturbed versions of the model input to generate
        for robustness evaluation.
    :param butter_finger_perturbation_prob: The probability that a given character will be perturbed.
        Used when perturbation_type is BUTTER_FINGER.
    :param random_uppercase_corrupt_proportion: Fraction of characters to be changed to uppercase.
        Used when perturbation_type is RANDOM_UPPER_CASE.
    :param whitespace_remove_prob: The probability of removing a whitespace character.
        Used when perturbation_type is WHITESPACE_ADD_REMOVE.
    :param whitespace_add_prob: The probability of adding a whitespace character after a non-whitespace character.
        Used when perturbation_type is WHITESPACE_ADD_REMOVE.
    """

    perturbation_type: str = BUTTER_FINGER
    num_perturbations: int = 5
    butter_finger_perturbation_prob: float = 0.1
    random_uppercase_corrupt_proportion: float = 0.1
    whitespace_add_prob: float = 0.05
    whitespace_remove_prob: float = 0.1

    def __post_init__(self) -> None:
        """Check that `perturbation_type` is one of the keys of SEMANTIC_PERTURBATIONS.

        The check is delegated to `fmeval.util.require`; presumably it raises when the
        condition is false (the exception type is defined there, not in this module).
        """
        require(
            self.perturbation_type in SEMANTIC_PERTURBATIONS,
            # The offending value is wrapped in a matching pair of single quotes.
            f"Invalid perturbation type '{self.perturbation_type}' requested, please "
            f"choose from acceptable values: {SEMANTIC_PERTURBATIONS.keys()}",
        )


def get_perturbation_transform(config: SemanticRobustnessConfig) -> SemanticPerturbation:
    """Returns a semantic perturbation transform based on parameters in `config`.

    The transform reads from the model input column and is given one output key per
    perturbation, each built by `create_output_key` from the transform class name,
    the model input column name, and the perturbation index.

    :param config: A config that specifies a perturbation type, which dictates the
        SemanticPerturbation that gets returned, and its configurable parameters.
    :returns: A SemanticPerturbation instance, initialized with parameters passed via `config`.
    """
    if config.perturbation_type == BUTTER_FINGER:
        return ButterFinger(
            input_key=DatasetColumns.MODEL_INPUT.value.name,
            output_keys=[
                create_output_key(ButterFinger.__name__, DatasetColumns.MODEL_INPUT.value.name, i)
                for i in range(config.num_perturbations)
            ],
            num_perturbations=config.num_perturbations,
            perturbation_prob=config.butter_finger_perturbation_prob,
        )
    elif config.perturbation_type == RANDOM_UPPER_CASE:
        return RandomUppercase(
            input_key=DatasetColumns.MODEL_INPUT.value.name,
            output_keys=[
                create_output_key(RandomUppercase.__name__, DatasetColumns.MODEL_INPUT.value.name, i)
                for i in range(config.num_perturbations)
            ],
            num_perturbations=config.num_perturbations,
            uppercase_fraction=config.random_uppercase_corrupt_proportion,
        )
    else:
        # Every value other than the two handled above ends up here; the config's
        # __post_init__ check restricts perturbation_type to the three supported names.
        return AddRemoveWhitespace(
            input_key=DatasetColumns.MODEL_INPUT.value.name,
            output_keys=[
                create_output_key(AddRemoveWhitespace.__name__, DatasetColumns.MODEL_INPUT.value.name, i)
                for i in range(config.num_perturbations)
            ],
            num_perturbations=config.num_perturbations,
            add_prob=config.whitespace_add_prob,
            remove_prob=config.whitespace_remove_prob,
        )


def get_model_outputs_from_perturbed_inputs(
    perturbation: SemanticPerturbation,
    prompt_template: str,
    model: ModelRunner,
) -> Tuple[SemanticPerturbation, GeneratePrompt, GetModelOutputs]:
    """Returns a tuple of transforms for perturbing model inputs, composing prompts, and getting model outputs.

    :param perturbation: The semantic perturbation transform used to perturb inputs.
    :param prompt_template: The template used for composing prompts out of the perturbed inputs.
    :param model: The model that is invoked on the prompts constructed from perturbed inputs.
    :returns: A tuple of three transforms, where the first is the same SemanticPerturbation
        that was passed in, and the second two are created in this function.
    """
    # Generate prompts from perturbed inputs: one prompt key per perturbed input key.
    gen_perturbed_prompts = GeneratePrompt(
        input_keys=perturbation.output_keys,
        output_keys=[
            create_output_key(GeneratePrompt.__name__, perturbed_input_key)
            for perturbed_input_key in perturbation.output_keys
        ],
        prompt_template=prompt_template,
    )

    # Invoke model with prompts generated above. Each prompt key maps to a
    # single-element list holding the key for that prompt's model output.
    get_perturbed_outputs = GetModelOutputs(
        input_to_output_keys={
            perturbed_prompt_key: [create_output_key(GetModelOutputs.__name__, perturbed_prompt_key)]
            for perturbed_prompt_key in gen_perturbed_prompts.output_keys
        },
        model_runner=model,
    )

    return perturbation, gen_perturbed_prompts, get_perturbed_outputs
SEMANTIC_PERTURBATIONS =
{'butter_finger': <class 'fmeval.transforms.semantic_perturbations.ButterFinger'>, 'random_upper_case': <class 'fmeval.transforms.semantic_perturbations.RandomUppercase'>, 'whitespace_add_remove': <class 'fmeval.transforms.semantic_perturbations.AddRemoveWhitespace'>}
@dataclass(frozen=True)
class SemanticRobustnessConfig(EvalAlgorithmConfig):
    """Configures the semantic robustness evaluation algorithms.

    :param perturbation_type: Perturbation type for generating perturbed inputs.
        Either BUTTER_FINGER, RANDOM_UPPER_CASE, or WHITESPACE_ADD_REMOVE.
    :param num_perturbations: Number of perturbed versions of the model input to generate
        for robustness evaluation.
    :param butter_finger_perturbation_prob: The probability that a given character will be perturbed.
        Used when perturbation_type is BUTTER_FINGER.
    :param random_uppercase_corrupt_proportion: Fraction of characters to be changed to uppercase.
        Used when perturbation_type is RANDOM_UPPER_CASE.
    :param whitespace_remove_prob: The probability of removing a whitespace character.
        Used when perturbation_type is WHITESPACE_ADD_REMOVE.
    :param whitespace_add_prob: The probability of adding a whitespace character after a non-whitespace character.
        Used when perturbation_type is WHITESPACE_ADD_REMOVE.
    """

    perturbation_type: str = BUTTER_FINGER
    num_perturbations: int = 5
    butter_finger_perturbation_prob: float = 0.1
    random_uppercase_corrupt_proportion: float = 0.1
    whitespace_add_prob: float = 0.05
    whitespace_remove_prob: float = 0.1

    def __post_init__(self) -> None:
        """Check that `perturbation_type` is one of the keys of SEMANTIC_PERTURBATIONS.

        The check is delegated to `require`; presumably it raises when the condition
        is false (the exception type is defined by that helper, not in this class).
        """
        require(
            self.perturbation_type in SEMANTIC_PERTURBATIONS,
            # The offending value is wrapped in a matching pair of single quotes.
            f"Invalid perturbation type '{self.perturbation_type}' requested, please "
            f"choose from acceptable values: {SEMANTIC_PERTURBATIONS.keys()}",
        )
Configures the semantic robustness evaluation algorithms.
Parameters
- perturbation_type: Perturbation type for generating perturbed inputs. Either BUTTER_FINGER, RANDOM_UPPER_CASE, or WHITESPACE_ADD_REMOVE.
- num_perturbations: Number of perturbed outputs to be generated for robustness evaluation.
- butter_finger_perturbation_prob: The probability that a given character will be perturbed. Used when perturbation_type is BUTTER_FINGER.
- random_uppercase_corrupt_proportion: Fraction of characters to be changed to uppercase. Used when perturbation_type is RANDOM_UPPER_CASE.
- whitespace_remove_prob: The probability of removing a whitespace character. Used when perturbation_type is WHITESPACE_ADD_REMOVE.
- whitespace_add_prob: The probability of adding a whitespace character after a non-whitespace character. Used when perturbation_type is WHITESPACE_ADD_REMOVE.
def get_perturbation_transform(config: SemanticRobustnessConfig) -> SemanticPerturbation:
    """Build the SemanticPerturbation transform selected by `config.perturbation_type`.

    The transform reads from the model input column and is given one output key per
    perturbation, each built by `create_output_key` from the transform class name,
    the model input column name, and the perturbation index.

    :param config: A config whose perturbation type picks the SemanticPerturbation
        subclass to instantiate, and whose remaining fields supply its parameters.
    :returns: A SemanticPerturbation instance, initialized with parameters taken from `config`.
    """
    source_key = DatasetColumns.MODEL_INPUT.value.name
    count = config.num_perturbations
    kind = config.perturbation_type

    def perturbed_keys(transform_cls):
        # One output key per perturbation, named after the transform class.
        return [create_output_key(transform_cls.__name__, source_key, idx) for idx in range(count)]

    if kind == BUTTER_FINGER:
        return ButterFinger(
            input_key=source_key,
            output_keys=perturbed_keys(ButterFinger),
            num_perturbations=count,
            perturbation_prob=config.butter_finger_perturbation_prob,
        )

    if kind == RANDOM_UPPER_CASE:
        return RandomUppercase(
            input_key=source_key,
            output_keys=perturbed_keys(RandomUppercase),
            num_perturbations=count,
            uppercase_fraction=config.random_uppercase_corrupt_proportion,
        )

    # Any other perturbation type falls through to the whitespace transform.
    return AddRemoveWhitespace(
        input_key=source_key,
        output_keys=perturbed_keys(AddRemoveWhitespace),
        num_perturbations=count,
        add_prob=config.whitespace_add_prob,
        remove_prob=config.whitespace_remove_prob,
    )
Returns a semantic perturbation transform based on parameters in `config`.
Parameters
- config: A config that specifies a perturbation type, which dictates the SemanticPerturbation that gets returned, and its configurable parameters.
Returns
- A SemanticPerturbation instance, initialized with parameters passed via `config`.
def get_model_outputs_from_perturbed_inputs(
    perturbation: SemanticPerturbation,
    prompt_template: str,
    model: ModelRunner,
) -> Tuple[SemanticPerturbation, GeneratePrompt, GetModelOutputs]:
    """Assemble the transforms that perturb inputs, build prompts, and query the model.

    :param perturbation: The semantic perturbation transform used to perturb inputs.
    :param prompt_template: The template applied to each perturbed input to compose a prompt.
    :param model: The model runner invoked on the prompts built from the perturbed inputs.
    :returns: A 3-tuple: the `perturbation` argument itself, followed by the
        GeneratePrompt and GetModelOutputs transforms constructed here.
    """
    # Step 1: one prompt key for every key the perturbation writes to.
    perturbed_input_keys = perturbation.output_keys
    prompt_keys = [create_output_key(GeneratePrompt.__name__, key) for key in perturbed_input_keys]
    prompt_step = GeneratePrompt(
        input_keys=perturbed_input_keys,
        output_keys=prompt_keys,
        prompt_template=prompt_template,
    )

    # Step 2: map each prompt key (as exposed by the prompt transform) to a
    # single-element list holding the key for that prompt's model output.
    prompt_to_output_keys = {}
    for prompt_key in prompt_step.output_keys:
        prompt_to_output_keys[prompt_key] = [create_output_key(GetModelOutputs.__name__, prompt_key)]
    model_step = GetModelOutputs(
        input_to_output_keys=prompt_to_output_keys,
        model_runner=model,
    )

    return perturbation, prompt_step, model_step
Returns a tuple of transforms for perturbing model inputs, composing prompts, and getting model outputs.
Parameters
- perturbation: The semantic perturbation transform used to perturb inputs.
- prompt_template: The template used for composing prompts out of the perturbed inputs.
- model: The model that is invoked on the prompts constructed from perturbed inputs.
Returns
- A tuple of three transforms, where the first is the same SemanticPerturbation that was passed in, and the second two are created in this function.