fmeval.eval_algorithms.semantic_robustness_utils

  1from dataclasses import dataclass
  2from typing import Tuple
  3
  4from fmeval.constants import BUTTER_FINGER, RANDOM_UPPER_CASE, WHITESPACE_ADD_REMOVE, DatasetColumns
  5from fmeval.eval_algorithms.eval_algorithm import EvalAlgorithmConfig
  6from fmeval.model_runners.model_runner import ModelRunner
  7from fmeval.transforms.common import GeneratePrompt, GetModelOutputs
  8from fmeval.transforms.semantic_perturbations import (
  9    SemanticPerturbation,
 10    ButterFinger,
 11    RandomUppercase,
 12    AddRemoveWhitespace,
 13)
 14from fmeval.transforms.util import create_output_key
 15from fmeval.util import require
 16
 17SEMANTIC_PERTURBATIONS = {
 18    BUTTER_FINGER: ButterFinger,
 19    RANDOM_UPPER_CASE: RandomUppercase,
 20    WHITESPACE_ADD_REMOVE: AddRemoveWhitespace,
 21}
 22
 23
 24@dataclass(frozen=True)
 25class SemanticRobustnessConfig(EvalAlgorithmConfig):
 26    """Configures the semantic robustness evaluation algorithms.
 27
 28    :param perturbation_type: Perturbation type for generating perturbed inputs.
 29        Either BUTTER_FINGER, RANDOM_UPPER_CASE, or WHITESPACE_ADD_REMOVE.
 30    :param num_perturbations: Number of perturbed outputs to be generated for robustness evaluation.
 31    :param butter_finger_perturbation_prob: The probability that a given character will be perturbed.
 32        Used when perturbation_type is BUTTER_FINGER.
 33    :param random_uppercase_corrupt_proportion: Fraction of characters to be changed to uppercase.
 34        Used when perturbation_type is RANDOM_UPPER_CASE.
 35    :param whitespace_remove_prob: The probability of removing a whitespace character.
 36        Used when perturbation_type is WHITESPACE_ADD_REMOVE.
 37    :param whitespace_add_prob: The probability of adding a whitespace character after a non-whitespace character.
 38        Used when perturbation_type is WHITESPACE_ADD_REMOVE.
 39    """
 40
 41    perturbation_type: str = BUTTER_FINGER
 42    num_perturbations: int = 5
 43    butter_finger_perturbation_prob: float = 0.1
 44    random_uppercase_corrupt_proportion: float = 0.1
 45    whitespace_add_prob: float = 0.05
 46    whitespace_remove_prob: float = 0.1
 47
 48    def __post_init__(self):
 49        require(
 50            self.perturbation_type in SEMANTIC_PERTURBATIONS,
 51            f"Invalid perturbation type '{self.perturbation_type} requested, please "
 52            f"choose from acceptable values: {SEMANTIC_PERTURBATIONS.keys()}",
 53        )
 54
 55
 56def get_perturbation_transform(config: SemanticRobustnessConfig) -> SemanticPerturbation:
 57    """Returns a semantic perturbation transform based on parameters in `config`.
 58
 59    :param config: A config that specifies a perturbation type, which dictates the
 60        SemanticPerturbation that gets returned, and its configurable parameters.
 61    :returns: A SemanticPerturbation instance, initialized with parameters passed via `config`.
 62    """
 63    if config.perturbation_type == BUTTER_FINGER:
 64        return ButterFinger(
 65            input_key=DatasetColumns.MODEL_INPUT.value.name,
 66            output_keys=[
 67                create_output_key(ButterFinger.__name__, DatasetColumns.MODEL_INPUT.value.name, i)
 68                for i in range(config.num_perturbations)
 69            ],
 70            num_perturbations=config.num_perturbations,
 71            perturbation_prob=config.butter_finger_perturbation_prob,
 72        )
 73    elif config.perturbation_type == RANDOM_UPPER_CASE:
 74        return RandomUppercase(
 75            input_key=DatasetColumns.MODEL_INPUT.value.name,
 76            output_keys=[
 77                create_output_key(RandomUppercase.__name__, DatasetColumns.MODEL_INPUT.value.name, i)
 78                for i in range(config.num_perturbations)
 79            ],
 80            num_perturbations=config.num_perturbations,
 81            uppercase_fraction=config.random_uppercase_corrupt_proportion,
 82        )
 83    else:
 84        return AddRemoveWhitespace(
 85            input_key=DatasetColumns.MODEL_INPUT.value.name,
 86            output_keys=[
 87                create_output_key(AddRemoveWhitespace.__name__, DatasetColumns.MODEL_INPUT.value.name, i)
 88                for i in range(config.num_perturbations)
 89            ],
 90            num_perturbations=config.num_perturbations,
 91            add_prob=config.whitespace_add_prob,
 92            remove_prob=config.whitespace_remove_prob,
 93        )
 94
 95
 96def get_model_outputs_from_perturbed_inputs(
 97    perturbation: SemanticPerturbation,
 98    prompt_template: str,
 99    model: ModelRunner,
100) -> Tuple[SemanticPerturbation, GeneratePrompt, GetModelOutputs]:
101    """Returns a tuple of transforms for perturbing model inputs, composing prompts, and getting model outputs.
102
103    :param perturbation: The semantic perturbation transform used to perturb inputs.
104    :param prompt_template: The template used for composing prompts out of the perturbed inputs.
105    :param model: The model that is invoked on the prompts constructed from perturbed inputs.
106    :returns: A tuple of three transforms, where the first is the same SemanticPerturbation
107        that was passed in, and the second two are created in this function.
108    """
109    # Generate prompts from perturbed inputs
110    gen_perturbed_prompts = GeneratePrompt(
111        input_keys=perturbation.output_keys,
112        output_keys=[
113            create_output_key(GeneratePrompt.__name__, perturbed_input_key)
114            for perturbed_input_key in perturbation.output_keys
115        ],
116        prompt_template=prompt_template,
117    )
118
119    # Invoke model with prompts generated above
120    get_perturbed_outputs = GetModelOutputs(
121        input_to_output_keys={
122            perturbed_prompt_key: [create_output_key(GetModelOutputs.__name__, perturbed_prompt_key)]
123            for perturbed_prompt_key in gen_perturbed_prompts.output_keys
124        },
125        model_runner=model,
126    )
127
128    return perturbation, gen_perturbed_prompts, get_perturbed_outputs
SEMANTIC_PERTURBATIONS = {'butter_finger': <class 'fmeval.transforms.semantic_perturbations.ButterFinger'>, 'random_upper_case': <class 'fmeval.transforms.semantic_perturbations.RandomUppercase'>, 'whitespace_add_remove': <class 'fmeval.transforms.semantic_perturbations.AddRemoveWhitespace'>}
@dataclass(frozen=True)
class SemanticRobustnessConfig(fmeval.eval_algorithms.eval_algorithm.EvalAlgorithmConfig):
@dataclass(frozen=True)
class SemanticRobustnessConfig(EvalAlgorithmConfig):
    """Configures the semantic robustness evaluation algorithms.

    :param perturbation_type: Perturbation type for generating perturbed inputs.
        Either BUTTER_FINGER, RANDOM_UPPER_CASE, or WHITESPACE_ADD_REMOVE.
    :param num_perturbations: Number of perturbed versions of the model input to generate;
        each one is turned into its own prompt and sent to the model.
    :param butter_finger_perturbation_prob: The probability that a given character will be perturbed.
        Used when perturbation_type is BUTTER_FINGER.
    :param random_uppercase_corrupt_proportion: Fraction of characters to be changed to uppercase.
        Used when perturbation_type is RANDOM_UPPER_CASE.
    :param whitespace_add_prob: The probability of adding a whitespace character after a non-whitespace character.
        Used when perturbation_type is WHITESPACE_ADD_REMOVE.
    :param whitespace_remove_prob: The probability of removing a whitespace character.
        Used when perturbation_type is WHITESPACE_ADD_REMOVE.
    """

    perturbation_type: str = BUTTER_FINGER
    num_perturbations: int = 5
    butter_finger_perturbation_prob: float = 0.1
    random_uppercase_corrupt_proportion: float = 0.1
    whitespace_add_prob: float = 0.05
    whitespace_remove_prob: float = 0.1

    def __post_init__(self):
        """Validate that `perturbation_type` is one of the keys of SEMANTIC_PERTURBATIONS.

        The check is delegated to `require`, which is given the error message below
        (presumably raising with it when the condition is false — see fmeval.util).
        """
        require(
            self.perturbation_type in SEMANTIC_PERTURBATIONS,
            # The closing quote after the type name was missing in the original message.
            f"Invalid perturbation type '{self.perturbation_type}' requested, please "
            f"choose from acceptable values: {SEMANTIC_PERTURBATIONS.keys()}",
        )

Configures the semantic robustness evaluation algorithms.

Parameters
  • perturbation_type: Perturbation type for generating perturbed inputs. Either BUTTER_FINGER, RANDOM_UPPER_CASE, or WHITESPACE_ADD_REMOVE.
  • num_perturbations: Number of perturbed outputs to be generated for robustness evaluation.
  • butter_finger_perturbation_prob: The probability that a given character will be perturbed. Used when perturbation_type is BUTTER_FINGER.
  • random_uppercase_corrupt_proportion: Fraction of characters to be changed to uppercase. Used when perturbation_type is RANDOM_UPPER_CASE.
  • whitespace_remove_prob: The probability of removing a whitespace character. Used when perturbation_type is WHITESPACE_ADD_REMOVE.
  • whitespace_add_prob: The probability of adding a whitespace character after a non-whitespace character. Used when perturbation_type is WHITESPACE_ADD_REMOVE.
SemanticRobustnessConfig( perturbation_type: str = 'butter_finger', num_perturbations: int = 5, butter_finger_perturbation_prob: float = 0.1, random_uppercase_corrupt_proportion: float = 0.1, whitespace_add_prob: float = 0.05, whitespace_remove_prob: float = 0.1)
perturbation_type: str = 'butter_finger'
num_perturbations: int = 5
butter_finger_perturbation_prob: float = 0.1
random_uppercase_corrupt_proportion: float = 0.1
whitespace_add_prob: float = 0.05
whitespace_remove_prob: float = 0.1
def get_perturbation_transform( config: SemanticRobustnessConfig) -> fmeval.transforms.semantic_perturbations.SemanticPerturbation:
def get_perturbation_transform(config: SemanticRobustnessConfig) -> SemanticPerturbation:
    """Returns a semantic perturbation transform based on parameters in `config`.

    The transform class is looked up in SEMANTIC_PERTURBATIONS, the same mapping that
    SemanticRobustnessConfig validates against, so an unsupported type is rejected
    instead of silently falling back to a default transform.

    :param config: A config that specifies a perturbation type, which dictates the
        SemanticPerturbation that gets returned, and its configurable parameters.
    :returns: A SemanticPerturbation instance, initialized with parameters passed via `config`.
    """
    perturbation_type = config.perturbation_type
    require(
        perturbation_type in SEMANTIC_PERTURBATIONS,
        f"Invalid perturbation type '{perturbation_type}' requested, please "
        f"choose from acceptable values: {SEMANTIC_PERTURBATIONS.keys()}",
    )

    # Constructor arguments that differ per perturbation type; the shared ones
    # (input/output keys and perturbation count) are supplied below.
    type_specific_kwargs = {
        BUTTER_FINGER: {"perturbation_prob": config.butter_finger_perturbation_prob},
        RANDOM_UPPER_CASE: {"uppercase_fraction": config.random_uppercase_corrupt_proportion},
        WHITESPACE_ADD_REMOVE: {
            "add_prob": config.whitespace_add_prob,
            "remove_prob": config.whitespace_remove_prob,
        },
    }

    perturbation_cls = SEMANTIC_PERTURBATIONS[perturbation_type]
    input_key = DatasetColumns.MODEL_INPUT.value.name
    return perturbation_cls(
        input_key=input_key,
        # One output column per perturbation, named after the transform class.
        output_keys=[
            create_output_key(perturbation_cls.__name__, input_key, i) for i in range(config.num_perturbations)
        ],
        num_perturbations=config.num_perturbations,
        **type_specific_kwargs[perturbation_type],
    )

Returns a semantic perturbation transform based on parameters in config.

Parameters
  • config: A config that specifies a perturbation type, which dictates the SemanticPerturbation that gets returned, and its configurable parameters.

Returns
  • A SemanticPerturbation instance, initialized with parameters passed via config.
 97def get_model_outputs_from_perturbed_inputs(
 98    perturbation: SemanticPerturbation,
 99    prompt_template: str,
100    model: ModelRunner,
101) -> Tuple[SemanticPerturbation, GeneratePrompt, GetModelOutputs]:
102    """Returns a tuple of transforms for perturbing model inputs, composing prompts, and getting model outputs.
103
104    :param perturbation: The semantic perturbation transform used to perturb inputs.
105    :param prompt_template: The template used for composing prompts out of the perturbed inputs.
106    :param model: The model that is invoked on the prompts constructed from perturbed inputs.
107    :returns: A tuple of three transforms, where the first is the same SemanticPerturbation
108        that was passed in, and the second two are created in this function.
109    """
110    # Generate prompts from perturbed inputs
111    gen_perturbed_prompts = GeneratePrompt(
112        input_keys=perturbation.output_keys,
113        output_keys=[
114            create_output_key(GeneratePrompt.__name__, perturbed_input_key)
115            for perturbed_input_key in perturbation.output_keys
116        ],
117        prompt_template=prompt_template,
118    )
119
120    # Invoke model with prompts generated above
121    get_perturbed_outputs = GetModelOutputs(
122        input_to_output_keys={
123            perturbed_prompt_key: [create_output_key(GetModelOutputs.__name__, perturbed_prompt_key)]
124            for perturbed_prompt_key in gen_perturbed_prompts.output_keys
125        },
126        model_runner=model,
127    )
128
129    return perturbation, gen_perturbed_prompts, get_perturbed_outputs

Returns a tuple of transforms for perturbing model inputs, composing prompts, and getting model outputs.

Parameters
  • perturbation: The semantic perturbation transform used to perturb inputs.
  • prompt_template: The template used for composing prompts out of the perturbed inputs.
  • model: The model that is invoked on the prompts constructed from perturbed inputs.

Returns
  • A tuple of three transforms, where the first is the same SemanticPerturbation that was passed in, and the second two are created in this function.