fmeval.eval_algorithms

  1import math
  2from dataclasses import dataclass
  3from enum import Enum
  4from typing import List, Optional, Type, Dict
  5from functional import seq
  6
  7from fmeval.constants import MIME_TYPE_JSONLINES, ABS_TOL
  8from fmeval.data_loaders.data_config import DataConfig
  9
 10
 11@dataclass(frozen=True)
 12class EvalScore:
 13    """
 14    The class that contains the aggregated scores computed for different eval offerings
 15
 16    :param name: The name of the eval score offering
 17    :param value: The aggregated score computed for the given eval offering
 18    :param error: A string error message for a failed evaluation.
 19    """
 20
 21    name: str
 22    value: Optional[float] = None
 23    error: Optional[str] = None
 24
 25    def __post_init__(self):  # pragma: no cover
 26        """Post initialisation validations for EvalScore"""
 27        assert self.value is not None or self.error is not None
 28
 29    def __eq__(self, other: Type["EvalScore"]):  # type: ignore[override]
 30        try:
 31            assert self.name == other.name
 32            if self.value is not None and other.value is not None:
 33                assert math.isclose(self.value, other.value, abs_tol=ABS_TOL)
 34                assert self.error is None
 35            else:
 36                assert self.value == other.value
 37                assert self.error == other.error
 38            return True
 39        except AssertionError:
 40            return False
 41
 42
 43class EvalAlgorithm(str, Enum):
 44    """The evaluation types supported by Amazon Foundation Model Evaluations.
 45
 46    The evaluation types are used to determine the evaluation metrics for the
 47    model.
 48    """
 49
 50    PROMPT_STEREOTYPING = "prompt_stereotyping"
 51    FACTUAL_KNOWLEDGE = "factual_knowledge"
 52    TOXICITY = "toxicity"
 53    QA_TOXICITY = "qa_toxicity"
 54    SUMMARIZATION_TOXICITY = "summarization_toxicity"
 55    GENERAL_SEMANTIC_ROBUSTNESS = "general_semantic_robustness"
 56    ACCURACY = "accuracy"
 57    QA_ACCURACY = "qa_accuracy"
 58    QA_ACCURACY_SEMANTIC_ROBUSTNESS = "qa_accuracy_semantic_robustness"
 59    SUMMARIZATION_ACCURACY = "summarization_accuracy"
 60    SUMMARIZATION_ACCURACY_SEMANTIC_ROBUSTNESS = "summarization_accuracy_semantic_robustness"
 61    CLASSIFICATION_ACCURACY = "classification_accuracy"
 62    CLASSIFICATION_ACCURACY_SEMANTIC_ROBUSTNESS = "classification_accuracy_semantic_robustness"
 63
 64    def __str__(self):
 65        """
 66        Returns a prettified name
 67        """
 68        return self.name.replace("_", " ")
 69
 70
 71@dataclass(frozen=True)
 72class CategoryScore:
 73    """The class that contains the aggregated scores computed across specific categories in the dataset.
 74
 75    :param name: The name of the category.
 76    :param scores: The aggregated score computed for the given category.
 77    """
 78
 79    name: str
 80    scores: List[EvalScore]
 81
 82    def __eq__(self, other: Type["CategoryScore"]):  # type: ignore[override]
 83        try:
 84            assert self.name == other.name
 85            assert len(self.scores) == len(other.scores)
 86            assert seq(self.scores).sorted(key=lambda score: score.name).zip(
 87                seq(other.scores).sorted(key=lambda score: score.name)
 88            ).filter(lambda item: item[0] == item[1]).len() == len(self.scores)
 89            return True
 90        except AssertionError:
 91            return False
 92
 93
 94@dataclass(frozen=True)
 95class EvalOutput:
 96    """
 97    The class that contains evaluation scores from `EvalAlgorithmInterface`.
 98
 99    :param eval_name: The name of the evaluation
100    :param dataset_name: The name of dataset used by eval_algo
101    :param prompt_template: A template used to compose prompts, only consumed if model_output is not provided in dataset
102    :param dataset_scores: The aggregated score computed across the whole dataset.
103    :param category_scores: A list of CategoryScore object that contain the scores for each category in the dataset.
104    :param output_path: Local path of eval output on dataset. This output contains prompt-response with
105    record wise eval scores
106    :param error: A string error message for a failed evaluation.
107    """
108
109    eval_name: str
110    dataset_name: str
111    dataset_scores: Optional[List[EvalScore]] = None
112    prompt_template: Optional[str] = None
113    category_scores: Optional[List[CategoryScore]] = None
114    output_path: Optional[str] = None
115    error: Optional[str] = None
116
117    def __post_init__(self):  # pragma: no cover
118        """Post initialisation validations for EvalOutput"""
119        assert self.dataset_scores is not None or self.error is not None
120
121        if not self.category_scores:
122            return
123
124        dataset_score_names = [eval_score.name for eval_score in self.dataset_scores]
125        if self.category_scores:
126            for category_score in self.category_scores:
127                assert len(category_score.scores) == len(self.dataset_scores)
128                assert dataset_score_names == [
129                    category_eval_score.name for category_eval_score in category_score.scores
130                ]
131
132    def __eq__(self, other: Type["EvalOutput"]):  # type: ignore[override]
133        try:
134            assert self.eval_name == other.eval_name
135            assert self.dataset_name == other.dataset_name
136            assert self.prompt_template == other.prompt_template
137            assert self.error == other.error
138            assert self.dataset_scores if other.dataset_scores else not self.dataset_scores
139            if self.dataset_scores:  # pragma: no branch
140                assert self.dataset_scores and other.dataset_scores
141                assert len(self.dataset_scores) == len(other.dataset_scores)
142                assert seq(self.dataset_scores).sorted(key=lambda x: x.name).zip(
143                    seq(other.dataset_scores).sorted(key=lambda x: x.name)
144                ).filter(lambda x: x[0] == x[1]).len() == len(self.dataset_scores)
145            assert self.category_scores if other.category_scores else not self.category_scores
146            if self.category_scores:
147                assert seq(self.category_scores).sorted(key=lambda cat_score: cat_score.name).zip(
148                    seq(other.category_scores).sorted(key=lambda cat_score: cat_score.name)
149                ).filter(lambda item: item[0] == item[1]).len() == len(self.category_scores)
150            return True
151        except AssertionError:
152            return False
153
154
class ModelTask(str, Enum):
    """Task types that the evaluations support.

    A model's task determines which evaluation metrics apply to it. Each
    member's value is the lowercase identifier string of the task.
    """

    NO_TASK = "no_task"
    CLASSIFICATION = "classification"
    QUESTION_ANSWERING = "question_answering"
    SUMMARIZATION = "summarization"
166
167
# Mapping of model tasks to the evaluations listed for each task. These mappings are for representational
# purposes only and are not meant to be consumed by any use case.
# NOTE(review): NO_TASK lists only four evaluations rather than every EvalAlgorithm member, and
# EvalAlgorithm.ACCURACY is not listed under any task — confirm this is intended.
MODEL_TASK_EVALUATION_MAP = {
    ModelTask.NO_TASK: [
        EvalAlgorithm.PROMPT_STEREOTYPING,
        EvalAlgorithm.FACTUAL_KNOWLEDGE,
        EvalAlgorithm.TOXICITY,
        EvalAlgorithm.GENERAL_SEMANTIC_ROBUSTNESS,
    ],
    ModelTask.CLASSIFICATION: [
        EvalAlgorithm.CLASSIFICATION_ACCURACY,
        EvalAlgorithm.CLASSIFICATION_ACCURACY_SEMANTIC_ROBUSTNESS,
    ],
    ModelTask.QUESTION_ANSWERING: [
        EvalAlgorithm.QA_TOXICITY,
        EvalAlgorithm.QA_ACCURACY,
        EvalAlgorithm.QA_ACCURACY_SEMANTIC_ROBUSTNESS,
    ],
    ModelTask.SUMMARIZATION: [
        EvalAlgorithm.SUMMARIZATION_TOXICITY,
        EvalAlgorithm.SUMMARIZATION_ACCURACY,
        EvalAlgorithm.SUMMARIZATION_ACCURACY_SEMANTIC_ROBUSTNESS,
    ],
}

# Constants for built-in dataset names
TREX = "trex"
BOOLQ = "boolq"
TRIVIA_QA = "trivia_qa"
NATURAL_QUESTIONS = "natural_questions"
CROWS_PAIRS = "crows-pairs"
GIGAWORD = "gigaword"
GOV_REPORT = "gov_report"
WOMENS_CLOTHING_ECOMMERCE_REVIEWS = "womens_clothing_ecommerce_reviews"
BOLD = "bold"
WIKITEXT2 = "wikitext2"
REAL_TOXICITY_PROMPTS = "real_toxicity_prompts"
REAL_TOXICITY_PROMPTS_CHALLENGING = "real_toxicity_prompts_challenging"

# Mapping of eval algorithm names (EvalAlgorithm values) to the built-in datasets available for each
EVAL_DATASETS: Dict[str, List[str]] = {
    EvalAlgorithm.FACTUAL_KNOWLEDGE.value: [TREX],
    EvalAlgorithm.QA_ACCURACY.value: [BOOLQ, TRIVIA_QA, NATURAL_QUESTIONS],
    EvalAlgorithm.QA_ACCURACY_SEMANTIC_ROBUSTNESS.value: [BOOLQ, TRIVIA_QA, NATURAL_QUESTIONS],
    EvalAlgorithm.PROMPT_STEREOTYPING.value: [CROWS_PAIRS],
    EvalAlgorithm.SUMMARIZATION_ACCURACY.value: [GIGAWORD, GOV_REPORT],
    EvalAlgorithm.GENERAL_SEMANTIC_ROBUSTNESS.value: [BOLD, TREX, WIKITEXT2],
    EvalAlgorithm.CLASSIFICATION_ACCURACY.value: [WOMENS_CLOTHING_ECOMMERCE_REVIEWS],
    EvalAlgorithm.CLASSIFICATION_ACCURACY_SEMANTIC_ROBUSTNESS.value: [
        WOMENS_CLOTHING_ECOMMERCE_REVIEWS,
    ],
    EvalAlgorithm.SUMMARIZATION_ACCURACY_SEMANTIC_ROBUSTNESS.value: [GIGAWORD, GOV_REPORT],
    EvalAlgorithm.TOXICITY.value: [BOLD, REAL_TOXICITY_PROMPTS, REAL_TOXICITY_PROMPTS_CHALLENGING],
    EvalAlgorithm.QA_TOXICITY.value: [BOOLQ, TRIVIA_QA, NATURAL_QUESTIONS],
    EvalAlgorithm.SUMMARIZATION_TOXICITY.value: [GIGAWORD, GOV_REPORT],
}

# Generic prompt template, used for datasets that have no dedicated default template
DEFAULT_PROMPT_TEMPLATE = "$model_input"

# Default prompt templates for built-in datasets, keyed by dataset name
BUILT_IN_DATASET_DEFAULT_PROMPT_TEMPLATES = {
    BOOLQ: 'Respond to the following question. Valid answers are "True" or "False". $model_input',
    TRIVIA_QA: "Respond to the following question with a short answer: $model_input",
    NATURAL_QUESTIONS: "Respond to the following question with a short answer: $model_input",
    GIGAWORD: "Summarize the following text in one sentence: $model_input",
    GOV_REPORT: "Summarize the following text in a few sentences: $model_input",
    WOMENS_CLOTHING_ECOMMERCE_REVIEWS: (
        "Classify the sentiment of the following review with 0 (negative sentiment)"
        " or 1 (positive sentiment): $model_input"
    ),
}
239
240
def get_default_prompt_template(dataset_name: str) -> str:
    """
    Util method to provide dataset-specific default prompt templates. If no default is configured for the dataset,
    the method returns a generic default prompt template.

    :param dataset_name: Name of dataset
    :return: The template registered for `dataset_name` in BUILT_IN_DATASET_DEFAULT_PROMPT_TEMPLATES,
        or DEFAULT_PROMPT_TEMPLATE if the dataset has no entry there.
    """
    return BUILT_IN_DATASET_DEFAULT_PROMPT_TEMPLATES.get(dataset_name, DEFAULT_PROMPT_TEMPLATE)
248
249
# Mapping of built-in dataset names to their DataConfigs. Each key is taken from the config's own
# dataset_name, so a key and the config it points to cannot get out of sync.
DATASET_CONFIGS: Dict[str, DataConfig] = {
    config.dataset_name: config
    for config in (
        DataConfig(
            dataset_name=TREX,
            dataset_uri="s3://fmeval/datasets/trex/trex.jsonl",
            dataset_mime_type=MIME_TYPE_JSONLINES,
            model_input_location="question",
            target_output_location="answers",
            category_location="knowledge_category",
        ),
        DataConfig(
            dataset_name=BOOLQ,
            dataset_uri="s3://fmeval/datasets/boolq/boolq.jsonl",
            dataset_mime_type=MIME_TYPE_JSONLINES,
            model_input_location="question",
            target_output_location="answer",
        ),
        DataConfig(
            dataset_name=TRIVIA_QA,
            dataset_uri="s3://fmeval/datasets/triviaQA/triviaQA.jsonl",
            dataset_mime_type=MIME_TYPE_JSONLINES,
            model_input_location="question",
            target_output_location="answer",
        ),
        DataConfig(
            dataset_name=NATURAL_QUESTIONS,
            dataset_uri="s3://fmeval/datasets/natural_questions/natural_questions.jsonl",
            dataset_mime_type=MIME_TYPE_JSONLINES,
            model_input_location="question",
            target_output_location="answer",
        ),
        DataConfig(
            dataset_name=CROWS_PAIRS,
            dataset_uri="s3://fmeval/datasets/crows-pairs/crows-pairs.jsonl",
            dataset_mime_type=MIME_TYPE_JSONLINES,
            sent_more_input_location="sent_more",
            sent_less_input_location="sent_less",
            category_location="bias_type",
        ),
        DataConfig(
            dataset_name=WOMENS_CLOTHING_ECOMMERCE_REVIEWS,
            dataset_uri="s3://fmeval/datasets/womens_clothing_reviews/womens_clothing_reviews.jsonl",
            dataset_mime_type=MIME_TYPE_JSONLINES,
            model_input_location='"Review Text"',
            target_output_location='"Recommended IND"',
            category_location='"Class Name"',
        ),
        DataConfig(
            dataset_name=BOLD,
            dataset_uri="s3://fmeval/datasets/bold/bold.jsonl",
            dataset_mime_type=MIME_TYPE_JSONLINES,
            model_input_location="prompt",
            category_location="domain",
        ),
        DataConfig(
            dataset_name=WIKITEXT2,
            dataset_uri="s3://fmeval/datasets/wikitext2/wikitext2.jsonl",
            dataset_mime_type=MIME_TYPE_JSONLINES,
            model_input_location="prompt",
        ),
        DataConfig(
            dataset_name=REAL_TOXICITY_PROMPTS,
            dataset_uri="s3://fmeval/datasets/real_toxicity/real_toxicity.jsonl",
            dataset_mime_type=MIME_TYPE_JSONLINES,
            model_input_location="prompt",
        ),
        DataConfig(
            dataset_name=REAL_TOXICITY_PROMPTS_CHALLENGING,
            dataset_uri="s3://fmeval/datasets/real_toxicity/real_toxicity_challenging.jsonl",
            dataset_mime_type=MIME_TYPE_JSONLINES,
            model_input_location="prompt",
        ),
        DataConfig(
            dataset_name=GIGAWORD,
            dataset_uri="s3://fmeval/datasets/gigaword/gigaword.jsonl",
            dataset_mime_type=MIME_TYPE_JSONLINES,
            model_input_location="document",
            target_output_location="summary",
        ),
        DataConfig(
            dataset_name=GOV_REPORT,
            dataset_uri="s3://fmeval/datasets/gov_report/gov_report.jsonl",
            dataset_mime_type=MIME_TYPE_JSONLINES,
            model_input_location="report",
            target_output_location="summary",
        ),
    )
}
@dataclass(frozen=True)
class EvalScore:
    """
    The class that contains the aggregated score computed for an eval offering

    :param name: The name of the eval score offering
    :param value: The aggregated score computed for the given eval offering
    :param error: A string error message for a failed evaluation.
    """

    name: str
    value: Optional[float] = None
    error: Optional[str] = None

    def __post_init__(self):  # pragma: no cover
        """Post initialisation validations for EvalScore: a value or an error must be provided."""
        assert self.value is not None or self.error is not None, "EvalScore requires a value or an error"

    def __eq__(self, other: object):  # type: ignore[override]
        """
        Compare two EvalScores, treating numeric values that differ by at most ABS_TOL as equal.

        NOTE(review): the dataclass-generated __hash__ (frozen=True, eq=True) hashes the exact field values,
        so two scores that compare equal within ABS_TOL can still hash differently.

        :param other: The object to compare against.
        :return: True if names and errors match and the values are equal (within ABS_TOL when both are set),
            False otherwise, or NotImplemented if `other` is not an EvalScore.
        """
        if not isinstance(other, EvalScore):
            # Let Python fall back to its default comparison instead of failing on `other.name`.
            return NotImplemented
        # Errors are compared in every case so that equality stays symmetric and reflexive.
        if self.name != other.name or self.error != other.error:
            return False
        if self.value is not None and other.value is not None:
            return math.isclose(self.value, other.value, abs_tol=ABS_TOL)
        return self.value == other.value

The class that contains the aggregated scores computed for different eval offerings

Parameters
  • name: The name of the eval score offering
  • value: The aggregated score computed for the given eval offering
  • error: A string error message for a failed evaluation.
EvalScore( name: str, value: Optional[float] = None, error: Optional[str] = None)
name: str
value: Optional[float] = None
error: Optional[str] = None
class EvalAlgorithm(str, Enum):
    """Evaluation types offered by Amazon Foundation Model Evaluations.

    Each member's value is the identifier string of an evaluation; the chosen
    evaluation type determines which metrics are computed for the model.
    """

    PROMPT_STEREOTYPING = "prompt_stereotyping"
    FACTUAL_KNOWLEDGE = "factual_knowledge"
    TOXICITY = "toxicity"
    QA_TOXICITY = "qa_toxicity"
    SUMMARIZATION_TOXICITY = "summarization_toxicity"
    GENERAL_SEMANTIC_ROBUSTNESS = "general_semantic_robustness"
    ACCURACY = "accuracy"
    QA_ACCURACY = "qa_accuracy"
    QA_ACCURACY_SEMANTIC_ROBUSTNESS = "qa_accuracy_semantic_robustness"
    SUMMARIZATION_ACCURACY = "summarization_accuracy"
    SUMMARIZATION_ACCURACY_SEMANTIC_ROBUSTNESS = "summarization_accuracy_semantic_robustness"
    CLASSIFICATION_ACCURACY = "classification_accuracy"
    CLASSIFICATION_ACCURACY_SEMANTIC_ROBUSTNESS = "classification_accuracy_semantic_robustness"

    def __str__(self):
        """Return the member name with every underscore turned into a space, e.g. "QA ACCURACY"."""
        words = self.name.split("_")
        return " ".join(words)

The evaluation types supported by Amazon Foundation Model Evaluations.

The evaluation types are used to determine the evaluation metrics for the model.

PROMPT_STEREOTYPING = <EvalAlgorithm.PROMPT_STEREOTYPING: 'prompt_stereotyping'>
FACTUAL_KNOWLEDGE = <EvalAlgorithm.FACTUAL_KNOWLEDGE: 'factual_knowledge'>
TOXICITY = <EvalAlgorithm.TOXICITY: 'toxicity'>
QA_TOXICITY = <EvalAlgorithm.QA_TOXICITY: 'qa_toxicity'>
SUMMARIZATION_TOXICITY = <EvalAlgorithm.SUMMARIZATION_TOXICITY: 'summarization_toxicity'>
GENERAL_SEMANTIC_ROBUSTNESS = <EvalAlgorithm.GENERAL_SEMANTIC_ROBUSTNESS: 'general_semantic_robustness'>
ACCURACY = <EvalAlgorithm.ACCURACY: 'accuracy'>
QA_ACCURACY = <EvalAlgorithm.QA_ACCURACY: 'qa_accuracy'>
QA_ACCURACY_SEMANTIC_ROBUSTNESS = <EvalAlgorithm.QA_ACCURACY_SEMANTIC_ROBUSTNESS: 'qa_accuracy_semantic_robustness'>
SUMMARIZATION_ACCURACY = <EvalAlgorithm.SUMMARIZATION_ACCURACY: 'summarization_accuracy'>
SUMMARIZATION_ACCURACY_SEMANTIC_ROBUSTNESS = <EvalAlgorithm.SUMMARIZATION_ACCURACY_SEMANTIC_ROBUSTNESS: 'summarization_accuracy_semantic_robustness'>
CLASSIFICATION_ACCURACY = <EvalAlgorithm.CLASSIFICATION_ACCURACY: 'classification_accuracy'>
CLASSIFICATION_ACCURACY_SEMANTIC_ROBUSTNESS = <EvalAlgorithm.CLASSIFICATION_ACCURACY_SEMANTIC_ROBUSTNESS: 'classification_accuracy_semantic_robustness'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
@dataclass(frozen=True)
class CategoryScore:
    """The class that contains the aggregated scores computed across specific categories in the dataset.

    :param name: The name of the category.
    :param scores: The aggregated scores computed for the given category.
    """

    name: str
    scores: List["EvalScore"]

    def __eq__(self, other: object):  # type: ignore[override]
        """
        Compare two CategoryScores; scores are matched by name, independent of list order.

        :param other: The object to compare against.
        :return: True if the names match and both score lists have the same length and are pairwise equal
            after sorting by score name, False otherwise, or NotImplemented if `other` is not a CategoryScore.
        """
        if not isinstance(other, CategoryScore):
            # Let Python fall back to its default comparison instead of failing on `other.name`.
            return NotImplemented
        if self.name != other.name or len(self.scores) != len(other.scores):
            return False
        ours = sorted(self.scores, key=lambda score: score.name)
        theirs = sorted(other.scores, key=lambda score: score.name)
        return all(mine == their_score for mine, their_score in zip(ours, theirs))

The class that contains the aggregated scores computed across specific categories in the dataset.

Parameters
  • name: The name of the category.
  • scores: The aggregated scores computed for the given category.
CategoryScore(name: str, scores: List[EvalScore])
name: str
scores: List[EvalScore]
@dataclass(frozen=True)
class EvalOutput:
    """
    The class that contains evaluation scores from `EvalAlgorithmInterface`.

    :param eval_name: The name of the evaluation
    :param dataset_name: The name of the dataset used by the eval algorithm
    :param prompt_template: A template used to compose prompts, only consumed if model_output is not provided in dataset
    :param dataset_scores: The aggregated scores computed across the whole dataset.
    :param category_scores: A list of CategoryScore objects that contain the scores for each category in the dataset.
    :param output_path: Local path of eval output on dataset. This output contains prompt-response pairs with
        record-wise eval scores
    :param error: A string error message for a failed evaluation.
    """

    eval_name: str
    dataset_name: str
    dataset_scores: Optional[List["EvalScore"]] = None
    prompt_template: Optional[str] = None
    category_scores: Optional[List["CategoryScore"]] = None
    output_path: Optional[str] = None
    error: Optional[str] = None

    def __post_init__(self):  # pragma: no cover
        """
        Post initialisation validations for EvalOutput.

        Requires dataset_scores or an error. When category_scores are given, dataset_scores must be present and
        every category must report exactly the dataset-level score names, in the same order.
        """
        assert self.dataset_scores is not None or self.error is not None, "dataset_scores or error is required"

        if not self.category_scores:
            return

        # Category scores are validated against the dataset scores, so those must be present.
        assert self.dataset_scores is not None, "category_scores cannot be provided without dataset_scores"
        dataset_score_names = [eval_score.name for eval_score in self.dataset_scores]
        for category_score in self.category_scores:
            category_score_names = [category_eval_score.name for category_eval_score in category_score.scores]
            # List equality also enforces that each category has as many scores as the dataset.
            assert category_score_names == dataset_score_names, (
                f"Scores for category {category_score.name!r} do not match the dataset scores"
            )

    def __eq__(self, other: object):  # type: ignore[override]
        """
        Compare two EvalOutputs.

        Dataset and category scores are matched by name, independent of list order. output_path is not compared.

        :param other: The object to compare against.
        :return: True if both outputs are equal, False otherwise, or NotImplemented if `other` is not an EvalOutput.
        """
        if not isinstance(other, EvalOutput):
            return NotImplemented
        return (
            self.eval_name == other.eval_name
            and self.dataset_name == other.dataset_name
            and self.prompt_template == other.prompt_template
            and self.error == other.error
            and self._scores_match(self.dataset_scores, other.dataset_scores)
            and self._scores_match(self.category_scores, other.category_scores)
        )

    @staticmethod
    def _scores_match(mine: Optional[list], theirs: Optional[list]) -> bool:
        """
        Order-insensitive comparison of two score lists whose items expose a `name`.

        :param mine: The first score list (may be None or empty).
        :param theirs: The second score list (may be None or empty).
        :return: True if both lists are empty/None, or if they have the same length and are pairwise equal
            after sorting each by name; False otherwise.
        """
        if not mine or not theirs:
            # Equal only when both sides are empty or None.
            return not mine and not theirs
        # zip() stops at the shorter list, so without this check a list would compare equal
        # to any longer list whose sorted prefix matches it.
        if len(mine) != len(theirs):
            return False
        return all(
            left == right
            for left, right in zip(sorted(mine, key=lambda s: s.name), sorted(theirs, key=lambda s: s.name))
        )

The class that contains evaluation scores from EvalAlgorithmInterface.

Parameters
  • eval_name: The name of the evaluation
  • dataset_name: The name of the dataset used by the eval algorithm
  • prompt_template: A template used to compose prompts, only consumed if model_output is not provided in dataset
  • dataset_scores: The aggregated scores computed across the whole dataset.
  • category_scores: A list of CategoryScore objects that contain the scores for each category in the dataset.
  • output_path: Local path of eval output on dataset. This output contains prompt-response pairs with record-wise eval scores
  • error: A string error message for a failed evaluation.
EvalOutput( eval_name: str, dataset_name: str, dataset_scores: Optional[List[EvalScore]] = None, prompt_template: Optional[str] = None, category_scores: Optional[List[CategoryScore]] = None, output_path: Optional[str] = None, error: Optional[str] = None)
eval_name: str
dataset_name: str
dataset_scores: Optional[List[EvalScore]] = None
prompt_template: Optional[str] = None
category_scores: Optional[List[CategoryScore]] = None
output_path: Optional[str] = None
error: Optional[str] = None
class ModelTask(str, Enum):
    """Task types that the evaluations support.

    A model's task determines which evaluation metrics apply to it. Each
    member's value is the lowercase identifier string of the task.
    """

    NO_TASK = "no_task"
    CLASSIFICATION = "classification"
    QUESTION_ANSWERING = "question_answering"
    SUMMARIZATION = "summarization"

The different types of tasks that are supported by the evaluations.

The model tasks are used to determine the evaluation metrics for the model.

NO_TASK = <ModelTask.NO_TASK: 'no_task'>
CLASSIFICATION = <ModelTask.CLASSIFICATION: 'classification'>
QUESTION_ANSWERING = <ModelTask.QUESTION_ANSWERING: 'question_answering'>
SUMMARIZATION = <ModelTask.SUMMARIZATION: 'summarization'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
MODEL_TASK_EVALUATION_MAP = {<ModelTask.NO_TASK: 'no_task'>: [<EvalAlgorithm.PROMPT_STEREOTYPING: 'prompt_stereotyping'>, <EvalAlgorithm.FACTUAL_KNOWLEDGE: 'factual_knowledge'>, <EvalAlgorithm.TOXICITY: 'toxicity'>, <EvalAlgorithm.GENERAL_SEMANTIC_ROBUSTNESS: 'general_semantic_robustness'>], <ModelTask.CLASSIFICATION: 'classification'>: [<EvalAlgorithm.CLASSIFICATION_ACCURACY: 'classification_accuracy'>, <EvalAlgorithm.CLASSIFICATION_ACCURACY_SEMANTIC_ROBUSTNESS: 'classification_accuracy_semantic_robustness'>], <ModelTask.QUESTION_ANSWERING: 'question_answering'>: [<EvalAlgorithm.QA_TOXICITY: 'qa_toxicity'>, <EvalAlgorithm.QA_ACCURACY: 'qa_accuracy'>, <EvalAlgorithm.QA_ACCURACY_SEMANTIC_ROBUSTNESS: 'qa_accuracy_semantic_robustness'>], <ModelTask.SUMMARIZATION: 'summarization'>: [<EvalAlgorithm.SUMMARIZATION_TOXICITY: 'summarization_toxicity'>, <EvalAlgorithm.SUMMARIZATION_ACCURACY: 'summarization_accuracy'>, <EvalAlgorithm.SUMMARIZATION_ACCURACY_SEMANTIC_ROBUSTNESS: 'summarization_accuracy_semantic_robustness'>]}
TREX = 'trex'
BOOLQ = 'boolq'
TRIVIA_QA = 'trivia_qa'
NATURAL_QUESTIONS = 'natural_questions'
CROWS_PAIRS = 'crows-pairs'
GIGAWORD = 'gigaword'
GOV_REPORT = 'gov_report'
WOMENS_CLOTHING_ECOMMERCE_REVIEWS = 'womens_clothing_ecommerce_reviews'
BOLD = 'bold'
WIKITEXT2 = 'wikitext2'
REAL_TOXICITY_PROMPTS = 'real_toxicity_prompts'
REAL_TOXICITY_PROMPTS_CHALLENGING = 'real_toxicity_prompts_challenging'
EVAL_DATASETS: Dict[str, List[str]] = {'factual_knowledge': ['trex'], 'qa_accuracy': ['boolq', 'trivia_qa', 'natural_questions'], 'qa_accuracy_semantic_robustness': ['boolq', 'trivia_qa', 'natural_questions'], 'prompt_stereotyping': ['crows-pairs'], 'summarization_accuracy': ['gigaword', 'gov_report'], 'general_semantic_robustness': ['bold', 'trex', 'wikitext2'], 'classification_accuracy': ['womens_clothing_ecommerce_reviews'], 'classification_accuracy_semantic_robustness': ['womens_clothing_ecommerce_reviews'], 'summarization_accuracy_semantic_robustness': ['gigaword', 'gov_report'], 'toxicity': ['bold', 'real_toxicity_prompts', 'real_toxicity_prompts_challenging'], 'qa_toxicity': ['boolq', 'trivia_qa', 'natural_questions'], 'summarization_toxicity': ['gigaword', 'gov_report']}
DEFAULT_PROMPT_TEMPLATE = '$model_input'
BUILT_IN_DATASET_DEFAULT_PROMPT_TEMPLATES = {'boolq': 'Respond to the following question. Valid answers are "True" or "False". $model_input', 'trivia_qa': 'Respond to the following question with a short answer: $model_input', 'natural_questions': 'Respond to the following question with a short answer: $model_input', 'gigaword': 'Summarize the following text in one sentence: $model_input', 'gov_report': 'Summarize the following text in a few sentences: $model_input', 'womens_clothing_ecommerce_reviews': 'Classify the sentiment of the following review with 0 (negative sentiment) or 1 (positive sentiment): $model_input'}
def get_default_prompt_template(dataset_name: str) -> str:
    """
    Util method to provide dataset-specific default prompt templates. If no default is configured for the dataset,
    the method returns a generic default prompt template.

    :param dataset_name: Name of dataset
    :return: The template registered for `dataset_name` in BUILT_IN_DATASET_DEFAULT_PROMPT_TEMPLATES,
        or DEFAULT_PROMPT_TEMPLATE if the dataset has no entry there.
    """
    return BUILT_IN_DATASET_DEFAULT_PROMPT_TEMPLATES.get(dataset_name, DEFAULT_PROMPT_TEMPLATE)

Util method to provide dataset-specific default prompt templates. If no default is configured for the dataset, the method returns a generic default prompt template.

Parameters
  • dataset_name: Name of dataset
DATASET_CONFIGS: Dict[str, fmeval.data_loaders.data_config.DataConfig] = {'trex': DataConfig(dataset_name='trex', dataset_uri='s3://fmeval/datasets/trex/trex.jsonl', dataset_mime_type='application/jsonlines', model_input_location='question', model_output_location=None, target_output_location='answers', category_location='knowledge_category', sent_more_input_location=None, sent_less_input_location=None, sent_more_log_prob_location=None, sent_less_log_prob_location=None, context_location=None), 'boolq': DataConfig(dataset_name='boolq', dataset_uri='s3://fmeval/datasets/boolq/boolq.jsonl', dataset_mime_type='application/jsonlines', model_input_location='question', model_output_location=None, target_output_location='answer', category_location=None, sent_more_input_location=None, sent_less_input_location=None, sent_more_log_prob_location=None, sent_less_log_prob_location=None, context_location=None), 'trivia_qa': DataConfig(dataset_name='trivia_qa', dataset_uri='s3://fmeval/datasets/triviaQA/triviaQA.jsonl', dataset_mime_type='application/jsonlines', model_input_location='question', model_output_location=None, target_output_location='answer', category_location=None, sent_more_input_location=None, sent_less_input_location=None, sent_more_log_prob_location=None, sent_less_log_prob_location=None, context_location=None), 'natural_questions': DataConfig(dataset_name='natural_questions', dataset_uri='s3://fmeval/datasets/natural_questions/natural_questions.jsonl', dataset_mime_type='application/jsonlines', model_input_location='question', model_output_location=None, target_output_location='answer', category_location=None, sent_more_input_location=None, sent_less_input_location=None, sent_more_log_prob_location=None, sent_less_log_prob_location=None, context_location=None), 'crows-pairs': DataConfig(dataset_name='crows-pairs', dataset_uri='s3://fmeval/datasets/crows-pairs/crows-pairs.jsonl', dataset_mime_type='application/jsonlines', model_input_location=None, 
model_output_location=None, target_output_location=None, category_location='bias_type', sent_more_input_location='sent_more', sent_less_input_location='sent_less', sent_more_log_prob_location=None, sent_less_log_prob_location=None, context_location=None), 'womens_clothing_ecommerce_reviews': DataConfig(dataset_name='womens_clothing_ecommerce_reviews', dataset_uri='s3://fmeval/datasets/womens_clothing_reviews/womens_clothing_reviews.jsonl', dataset_mime_type='application/jsonlines', model_input_location='"Review Text"', model_output_location=None, target_output_location='"Recommended IND"', category_location='"Class Name"', sent_more_input_location=None, sent_less_input_location=None, sent_more_log_prob_location=None, sent_less_log_prob_location=None, context_location=None), 'bold': DataConfig(dataset_name='bold', dataset_uri='s3://fmeval/datasets/bold/bold.jsonl', dataset_mime_type='application/jsonlines', model_input_location='prompt', model_output_location=None, target_output_location=None, category_location='domain', sent_more_input_location=None, sent_less_input_location=None, sent_more_log_prob_location=None, sent_less_log_prob_location=None, context_location=None), 'wikitext2': DataConfig(dataset_name='wikitext2', dataset_uri='s3://fmeval/datasets/wikitext2/wikitext2.jsonl', dataset_mime_type='application/jsonlines', model_input_location='prompt', model_output_location=None, target_output_location=None, category_location=None, sent_more_input_location=None, sent_less_input_location=None, sent_more_log_prob_location=None, sent_less_log_prob_location=None, context_location=None), 'real_toxicity_prompts': DataConfig(dataset_name='real_toxicity_prompts', dataset_uri='s3://fmeval/datasets/real_toxicity/real_toxicity.jsonl', dataset_mime_type='application/jsonlines', model_input_location='prompt', model_output_location=None, target_output_location=None, category_location=None, sent_more_input_location=None, sent_less_input_location=None, 
sent_more_log_prob_location=None, sent_less_log_prob_location=None, context_location=None), 'real_toxicity_prompts_challenging': DataConfig(dataset_name='real_toxicity_prompts_challenging', dataset_uri='s3://fmeval/datasets/real_toxicity/real_toxicity_challenging.jsonl', dataset_mime_type='application/jsonlines', model_input_location='prompt', model_output_location=None, target_output_location=None, category_location=None, sent_more_input_location=None, sent_less_input_location=None, sent_more_log_prob_location=None, sent_less_log_prob_location=None, context_location=None), 'gigaword': DataConfig(dataset_name='gigaword', dataset_uri='s3://fmeval/datasets/gigaword/gigaword.jsonl', dataset_mime_type='application/jsonlines', model_input_location='document', model_output_location=None, target_output_location='summary', category_location=None, sent_more_input_location=None, sent_less_input_location=None, sent_more_log_prob_location=None, sent_less_log_prob_location=None, context_location=None), 'gov_report': DataConfig(dataset_name='gov_report', dataset_uri='s3://fmeval/datasets/gov_report/gov_report.jsonl', dataset_mime_type='application/jsonlines', model_input_location='report', model_output_location=None, target_output_location='summary', category_location=None, sent_more_input_location=None, sent_less_input_location=None, sent_more_log_prob_location=None, sent_less_log_prob_location=None, context_location=None)}