fmeval.model_runners.util

Utilities for model runners.

  1"""
  2Utilities for model runners.
  3"""
  4import logging
  5import os
  6import json
  7from urllib import request
  8from typing import Literal
  9import boto3
 10import botocore.session
 11import botocore.config
 12import sagemaker
 13from functional import seq
 14
 15from fmeval.constants import (
 16    SAGEMAKER_SERVICE_ENDPOINT_URL,
 17    SAGEMAKER_RUNTIME_ENDPOINT_URL,
 18    DISABLE_FMEVAL_TELEMETRY,
 19    MODEL_ID,
 20    PROPRIETARY_SDK_MANIFEST_FILE,
 21    JUMPSTART_BUCKET_BASE_URL_FORMAT,
 22    JUMPSTART_BUCKET_BASE_URL_FORMAT_ENV_VAR,
 23)
 24from fmeval.util import get_fmeval_package_version
 25from mypy_boto3_bedrock.client import BedrockClient
 26from sagemaker.user_agent import get_user_agent_extra_suffix
 27from sagemaker.jumpstart.notebook_utils import list_jumpstart_models
 28
 29logger = logging.getLogger(__name__)
 30
 31
 32def get_user_agent_extra() -> str:
 33    """Return a string containing various user-agent headers to be passed to a botocore config.
 34
 35    This string will always contain SageMaker Python SDK headers obtained using the determine_prefix
 36    utility function from sagemaker.user_agent. If fmeval telemetry is enabled, this string will
 37    additionally contain an fmeval-specific header.
 38
 39    :return: A string to be used as the user_agent_extra parameter in a botocore config.
 40    """
 41    # Obtain user-agent headers for information such as SageMaker notebook instance type and SageMaker Studio app type.
 42    # We manually obtain these headers, so we can pass them in the user_agent_extra parameter of botocore.config.Config.
 43    # We can't rely on sagemaker.session.Session's initializer to fill in these headers for us, since we want to pass
 44    # our own sagemaker_client and sagemaker_runtime_client when creating a sagemaker.session.Session object.
 45    # When you pass these to the initializer, the python SDK code for constructing a botocore config with the SDK
 46    # headers won't get run.
 47    sagemaker_python_sdk_headers = get_user_agent_extra_suffix()
 48    return (
 49        sagemaker_python_sdk_headers
 50        if os.getenv(DISABLE_FMEVAL_TELEMETRY)
 51        else f"{sagemaker_python_sdk_headers} lib/fmeval#{get_fmeval_package_version()}"
 52    )
 53
 54
 55def get_boto_session(
 56    boto_retry_mode: Literal["legacy", "standard", "adaptive"],
 57    retry_attempts: int,
 58) -> boto3.session.Session:
 59    """
 60    Get boto3 session with adaptive retry config
 61    :return: The new session
 62    """
 63    botocore_session: botocore.session.Session = botocore.session.get_session()
 64    botocore_session.set_default_client_config(
 65        botocore.config.Config(
 66            # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html
 67            retries={"mode": boto_retry_mode, "max_attempts": retry_attempts},
 68            # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
 69            user_agent_extra=get_user_agent_extra(),
 70        )
 71    )
 72    return boto3.session.Session(botocore_session=botocore_session)
 73
 74
 75def get_sagemaker_session(
 76    boto_retry_mode: Literal["legacy", "standard", "adaptive"] = "adaptive",
 77    retry_attempts: int = 10,
 78) -> sagemaker.Session:
 79    """
 80    Get SageMaker session with adaptive retry config.
 81    :param boto_retry_mode: retry mode used for botocore config (legacy/standard/adaptive).
 82    :param retry_attempts: max retry attempts used for botocore client failures
 83    :return: The new session
 84    """
 85    boto_session = get_boto_session(boto_retry_mode, retry_attempts)
 86    sagemaker_service_endpoint_url = os.getenv(SAGEMAKER_SERVICE_ENDPOINT_URL)
 87    sagemaker_runtime_endpoint_url = os.getenv(SAGEMAKER_RUNTIME_ENDPOINT_URL)
 88    sagemaker_client = boto_session.client(
 89        service_name="sagemaker",
 90        endpoint_url=sagemaker_service_endpoint_url,
 91    )
 92    sagemaker_runtime_client = boto_session.client(
 93        service_name="sagemaker-runtime",
 94        endpoint_url=sagemaker_runtime_endpoint_url,
 95    )
 96    sagemaker_session = sagemaker.session.Session(
 97        boto_session=boto_session,
 98        sagemaker_client=sagemaker_client,
 99        sagemaker_runtime_client=sagemaker_runtime_client,
100    )
101    return sagemaker_session
102
103
def get_bedrock_runtime_client(
    boto_retry_mode: Literal["legacy", "standard", "adaptive"] = "adaptive",
    retry_attempts: int = 10,
) -> BedrockClient:
    """
    Get a Bedrock runtime client with the given retry config (adaptive by default).

    :param boto_retry_mode: retry mode used for botocore config (legacy/standard/adaptive).
    :param retry_attempts: max retry attempts used for botocore client failures
    :return: The new "bedrock-runtime" boto3 client.
    """
    # NOTE(review): the return annotation is BedrockClient, the stub type for the "bedrock"
    # service, but this function returns a "bedrock-runtime" client. BedrockRuntimeClient from
    # mypy_boto3_bedrock_runtime would match, if that stub package is available to the project.
    boto_session = get_boto_session(boto_retry_mode, retry_attempts)
    return boto_session.client(service_name="bedrock-runtime")
117
118
def is_endpoint_in_service(
    sagemaker_session: sagemaker.session.Session,
    endpoint_name: str,
) -> bool:
    """
    Check whether a SageMaker endpoint currently reports the "InService" status.

    :param sagemaker_session: SageMaker session to be reused.
    :param endpoint_name: SageMaker endpoint name.
    :return: True only if describe_endpoint returns a non-empty response whose
        "EndpointStatus" is "InService"; False otherwise.
    """
    response = sagemaker_session.sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
    if not response:
        return False
    return "EndpointStatus" in response and response["EndpointStatus"] == "InService"
133
134
def is_text_embedding_js_model(jumpstart_model_id: str) -> bool:
    """
    Check whether a JumpStart model id belongs to a text embedding model.

    The candidate ids come from list_jumpstart_models with the
    "search_keywords includes Text Embedding" filter, which is evaluated on every call.

    :param jumpstart_model_id: JumpStart model id.
    :return: True if the id is in the text embedding model list, False otherwise.
    """
    embedding_filter = "search_keywords includes Text Embedding"
    return jumpstart_model_id in list_jumpstart_models(embedding_filter)
142
143
def is_proprietary_js_model(region: str, jumpstart_model_id: str) -> bool:
    """
    Check whether a JumpStart model id is listed in the proprietary-model SDK manifest.

    The bucket base URL template is taken from the environment variable named by
    JUMPSTART_BUCKET_BASE_URL_FORMAT_ENV_VAR, falling back to JUMPSTART_BUCKET_BASE_URL_FORMAT.
    The region is substituted into both positional placeholders. The manifest is downloaded
    and parsed as JSON, and its entries are searched for a MODEL_ID value equal to the given id.

    :param region: Region of the JumpStart bucket.
    :param jumpstart_model_id: JumpStart model id.
    :return: Whether the provided model id is proprietary model or not.
    """
    url_template = os.environ.get(JUMPSTART_BUCKET_BASE_URL_FORMAT_ENV_VAR, JUMPSTART_BUCKET_BASE_URL_FORMAT)
    base_url = url_template.format(region, region)
    manifest_url = f"{base_url}/{PROPRIETARY_SDK_MANIFEST_FILE}"

    # NOTE(review): urlopen is called without an explicit timeout, so a stalled connection
    # blocks until the global socket default (if any) expires.
    with request.urlopen(manifest_url) as response:
        raw_manifest = response.read().decode("utf-8")

    entries = json.loads(raw_manifest)
    match = seq(entries).find(lambda entry: entry.get(MODEL_ID, None) == jumpstart_model_id)
    return match is not None
logger = <Logger fmeval.model_runners.util (WARNING)>
def get_user_agent_extra() -> str:
def get_user_agent_extra() -> str:
    """Return a string containing various user-agent headers to be passed to a botocore config.

    This string always contains the SageMaker Python SDK headers returned by the
    get_user_agent_extra_suffix utility function from sagemaker.user_agent. Unless fmeval
    telemetry is disabled (i.e. the environment variable named by DISABLE_FMEVAL_TELEMETRY is
    set to a non-empty value), the string additionally contains an fmeval-specific
    "lib/fmeval#<version>" header.

    :return: A string to be used as the user_agent_extra parameter in a botocore config.
    """
    # Obtain user-agent headers for information such as SageMaker notebook instance type and SageMaker Studio app type.
    # We manually obtain these headers, so we can pass them in the user_agent_extra parameter of botocore.config.Config.
    # We can't rely on sagemaker.session.Session's initializer to fill in these headers for us, since we want to pass
    # our own sagemaker_client and sagemaker_runtime_client when creating a sagemaker.session.Session object.
    # When you pass these to the initializer, the python SDK code for constructing a botocore config with the SDK
    # headers won't get run.
    sagemaker_python_sdk_headers = get_user_agent_extra_suffix()
    # os.getenv returns None when the variable is unset and "" when it is set but empty; both are
    # falsy. Any non-empty value (even "false" or "0") therefore disables the fmeval header.
    if os.getenv(DISABLE_FMEVAL_TELEMETRY):
        return sagemaker_python_sdk_headers
    return f"{sagemaker_python_sdk_headers} lib/fmeval#{get_fmeval_package_version()}"

Return a string containing various user-agent headers to be passed to a botocore config.

This string will always contain SageMaker Python SDK headers obtained using the get_user_agent_extra_suffix utility function from sagemaker.user_agent. Unless fmeval telemetry is disabled, this string will additionally contain an fmeval-specific header.

Returns

A string to be used as the user_agent_extra parameter in a botocore config.

def get_boto_session( boto_retry_mode: Literal['legacy', 'standard', 'adaptive'], retry_attempts: int) -> boto3.session.Session:
def get_boto_session(
    boto_retry_mode: Literal["legacy", "standard", "adaptive"],
    retry_attempts: int,
) -> boto3.session.Session:
    """
    Get a boto3 session whose clients default to the given retry configuration.

    A new botocore session is created and given a default client config. Clients created from
    the returned boto3 session therefore pick up these retry settings and the user-agent string
    from get_user_agent_extra() unless they pass their own config.

    :param boto_retry_mode: retry mode used for botocore config (legacy/standard/adaptive).
    :param retry_attempts: max retry attempts used for botocore client failures
    :return: The new session
    """
    botocore_session: botocore.session.Session = botocore.session.get_session()
    botocore_session.set_default_client_config(
        botocore.config.Config(
            # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html
            retries={"mode": boto_retry_mode, "max_attempts": retry_attempts},
            # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
            user_agent_extra=get_user_agent_extra(),
        )
    )
    return boto3.session.Session(botocore_session=botocore_session)

Get a boto3 session configured with the given retry mode and maximum number of retry attempts.

Returns

The new session

def get_sagemaker_session( boto_retry_mode: Literal['legacy', 'standard', 'adaptive'] = 'adaptive', retry_attempts: int = 10) -> sagemaker.session.Session:
 76def get_sagemaker_session(
 77    boto_retry_mode: Literal["legacy", "standard", "adaptive"] = "adaptive",
 78    retry_attempts: int = 10,
 79) -> sagemaker.Session:
 80    """
 81    Get SageMaker session with adaptive retry config.
 82    :param boto_retry_mode: retry mode used for botocore config (legacy/standard/adaptive).
 83    :param retry_attempts: max retry attempts used for botocore client failures
 84    :return: The new session
 85    """
 86    boto_session = get_boto_session(boto_retry_mode, retry_attempts)
 87    sagemaker_service_endpoint_url = os.getenv(SAGEMAKER_SERVICE_ENDPOINT_URL)
 88    sagemaker_runtime_endpoint_url = os.getenv(SAGEMAKER_RUNTIME_ENDPOINT_URL)
 89    sagemaker_client = boto_session.client(
 90        service_name="sagemaker",
 91        endpoint_url=sagemaker_service_endpoint_url,
 92    )
 93    sagemaker_runtime_client = boto_session.client(
 94        service_name="sagemaker-runtime",
 95        endpoint_url=sagemaker_runtime_endpoint_url,
 96    )
 97    sagemaker_session = sagemaker.session.Session(
 98        boto_session=boto_session,
 99        sagemaker_client=sagemaker_client,
100        sagemaker_runtime_client=sagemaker_runtime_client,
101    )
102    return sagemaker_session

Get SageMaker session with adaptive retry config.

Parameters
  • boto_retry_mode: retry mode used for botocore config (legacy/standard/adaptive).
  • retry_attempts: max retry attempts used for botocore client failures
Returns

The new session

def get_bedrock_runtime_client( boto_retry_mode: Literal['legacy', 'standard', 'adaptive'] = 'adaptive', retry_attempts: int = 10) -> mypy_boto3_bedrock.client.BedrockClient:
def get_bedrock_runtime_client(
    boto_retry_mode: Literal["legacy", "standard", "adaptive"] = "adaptive",
    retry_attempts: int = 10,
) -> BedrockClient:
    """
    Get a Bedrock runtime client with the given retry config (adaptive by default).

    :param boto_retry_mode: retry mode used for botocore config (legacy/standard/adaptive).
    :param retry_attempts: max retry attempts used for botocore client failures
    :return: The new "bedrock-runtime" boto3 client.
    """
    # NOTE(review): the return annotation is BedrockClient, the stub type for the "bedrock"
    # service, but this function returns a "bedrock-runtime" client. BedrockRuntimeClient from
    # mypy_boto3_bedrock_runtime would match, if that stub package is available to the project.
    boto_session = get_boto_session(boto_retry_mode, retry_attempts)
    return boto_session.client(service_name="bedrock-runtime")

Get Bedrock runtime client with adaptive retry config.

Parameters
  • boto_retry_mode: retry mode used for botocore config (legacy/standard/adaptive).
  • retry_attempts: max retry attempts used for botocore client failures
Returns

The new Bedrock runtime client.

def is_endpoint_in_service(sagemaker_session: sagemaker.session.Session, endpoint_name: str) -> bool:
def is_endpoint_in_service(
    sagemaker_session: sagemaker.session.Session,
    endpoint_name: str,
) -> bool:
    """
    Check whether a SageMaker endpoint currently reports the "InService" status.

    :param sagemaker_session: SageMaker session to be reused.
    :param endpoint_name: SageMaker endpoint name.
    :return: True only if describe_endpoint returns a non-empty response whose
        "EndpointStatus" is "InService"; False otherwise.
    """
    response = sagemaker_session.sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
    if not response:
        return False
    return "EndpointStatus" in response and response["EndpointStatus"] == "InService"
Parameters
  • sagemaker_session: SageMaker session to be reused.
  • endpoint_name: SageMaker endpoint name.
Returns

Whether the endpoint is in service

def is_text_embedding_js_model(jumpstart_model_id: str) -> bool:
def is_text_embedding_js_model(jumpstart_model_id: str) -> bool:
    """
    Check whether a JumpStart model id belongs to a text embedding model.

    The candidate ids come from list_jumpstart_models with the
    "search_keywords includes Text Embedding" filter, which is evaluated on every call.

    :param jumpstart_model_id: JumpStart model id.
    :return: True if the id is in the text embedding model list, False otherwise.
    """
    embedding_filter = "search_keywords includes Text Embedding"
    return jumpstart_model_id in list_jumpstart_models(embedding_filter)
Parameters
  • jumpstart_model_id: JumpStart model id.
Returns

Whether the provided model id is text embedding model or not.

def is_proprietary_js_model(region: str, jumpstart_model_id: str) -> bool:
def is_proprietary_js_model(region: str, jumpstart_model_id: str) -> bool:
    """
    Check whether a JumpStart model id is listed in the proprietary-model SDK manifest.

    The bucket base URL template is taken from the environment variable named by
    JUMPSTART_BUCKET_BASE_URL_FORMAT_ENV_VAR, falling back to JUMPSTART_BUCKET_BASE_URL_FORMAT.
    The region is substituted into both positional placeholders. The manifest is downloaded
    and parsed as JSON, and its entries are searched for a MODEL_ID value equal to the given id.

    :param region: Region of the JumpStart bucket.
    :param jumpstart_model_id: JumpStart model id.
    :return: Whether the provided model id is proprietary model or not.
    """
    url_template = os.environ.get(JUMPSTART_BUCKET_BASE_URL_FORMAT_ENV_VAR, JUMPSTART_BUCKET_BASE_URL_FORMAT)
    base_url = url_template.format(region, region)
    manifest_url = f"{base_url}/{PROPRIETARY_SDK_MANIFEST_FILE}"

    # NOTE(review): urlopen is called without an explicit timeout, so a stalled connection
    # blocks until the global socket default (if any) expires.
    with request.urlopen(manifest_url) as response:
        raw_manifest = response.read().decode("utf-8")

    entries = json.loads(raw_manifest)
    match = seq(entries).find(lambda entry: entry.get(MODEL_ID, None) == jumpstart_model_id)
    return match is not None
Parameters
  • region: Region of the JumpStart bucket.
  • jumpstart_model_id: JumpStart model id.
Returns

Whether the provided model id is proprietary model or not.