fmeval.model_runners.util
Utilities for model runners.
"""
Utilities for model runners.
"""
import logging
import os
import json
from urllib import request
from typing import Literal
import boto3
import botocore.session
import botocore.config
import sagemaker
from functional import seq

from fmeval.constants import (
    SAGEMAKER_SERVICE_ENDPOINT_URL,
    SAGEMAKER_RUNTIME_ENDPOINT_URL,
    DISABLE_FMEVAL_TELEMETRY,
    MODEL_ID,
    PROPRIETARY_SDK_MANIFEST_FILE,
    JUMPSTART_BUCKET_BASE_URL_FORMAT,
    JUMPSTART_BUCKET_BASE_URL_FORMAT_ENV_VAR,
)
from fmeval.util import get_fmeval_package_version
from mypy_boto3_bedrock.client import BedrockClient
from sagemaker.user_agent import get_user_agent_extra_suffix
from sagemaker.jumpstart.notebook_utils import list_jumpstart_models

logger = logging.getLogger(__name__)


def get_user_agent_extra() -> str:
    """Return a string containing various user-agent headers to be passed to a botocore config.

    This string will always contain SageMaker Python SDK headers obtained using the
    get_user_agent_extra_suffix utility function from sagemaker.user_agent. Unless the environment
    variable named by DISABLE_FMEVAL_TELEMETRY is set to a non-empty value, this string will
    additionally contain an fmeval-specific header of the form "lib/fmeval#<package version>".

    :return: A string to be used as the user_agent_extra parameter in a botocore config.
    """
    # Obtain user-agent headers for information such as SageMaker notebook instance type and SageMaker Studio app type.
    # We manually obtain these headers, so we can pass them in the user_agent_extra parameter of botocore.config.Config.
    # We can't rely on sagemaker.session.Session's initializer to fill in these headers for us, since we want to pass
    # our own sagemaker_client and sagemaker_runtime_client when creating a sagemaker.session.Session object.
    # When you pass these to the initializer, the python SDK code for constructing a botocore config with the SDK
    # headers won't get run.
    sagemaker_python_sdk_headers = get_user_agent_extra_suffix()
    # NOTE(review): os.getenv returns the raw string, so ANY non-empty value (including "false" or "0")
    # disables the fmeval header. Confirm this is the intended opt-out semantics.
    return (
        sagemaker_python_sdk_headers
        if os.getenv(DISABLE_FMEVAL_TELEMETRY)
        else f"{sagemaker_python_sdk_headers} lib/fmeval#{get_fmeval_package_version()}"
    )


def get_boto_session(
    boto_retry_mode: Literal["legacy", "standard", "adaptive"],
    retry_attempts: int,
) -> boto3.session.Session:
    """
    Get a boto3 session whose clients use the given retry mode and maximum retry attempts.

    The retry settings and the user-agent string from get_user_agent_extra() are installed as the
    default client config of a fresh botocore session, which then backs the returned boto3 session.

    :param boto_retry_mode: retry mode used for botocore config (legacy/standard/adaptive).
    :param retry_attempts: max retry attempts used for botocore client failures
    :return: The new session
    """
    botocore_session: botocore.session.Session = botocore.session.get_session()
    botocore_session.set_default_client_config(
        botocore.config.Config(
            # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html
            retries={"mode": boto_retry_mode, "max_attempts": retry_attempts},
            # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
            user_agent_extra=get_user_agent_extra(),
        )
    )
    return boto3.session.Session(botocore_session=botocore_session)


def get_sagemaker_session(
    boto_retry_mode: Literal["legacy", "standard", "adaptive"] = "adaptive",
    retry_attempts: int = 10,
) -> sagemaker.Session:
    """
    Get a SageMaker session with a configurable retry policy (adaptive retry mode by default).

    The SageMaker service and runtime clients are created explicitly so that optional endpoint URL
    overrides, read from the environment variables named by SAGEMAKER_SERVICE_ENDPOINT_URL and
    SAGEMAKER_RUNTIME_ENDPOINT_URL, can be applied.

    :param boto_retry_mode: retry mode used for botocore config (legacy/standard/adaptive).
    :param retry_attempts: max retry attempts used for botocore client failures
    :return: The new session
    """
    boto_session = get_boto_session(boto_retry_mode, retry_attempts)
    # os.getenv returns None when a variable is unset; endpoint_url=None keeps boto3's default endpoint.
    sagemaker_service_endpoint_url = os.getenv(SAGEMAKER_SERVICE_ENDPOINT_URL)
    sagemaker_runtime_endpoint_url = os.getenv(SAGEMAKER_RUNTIME_ENDPOINT_URL)
    sagemaker_client = boto_session.client(
        service_name="sagemaker",
        endpoint_url=sagemaker_service_endpoint_url,
    )
    sagemaker_runtime_client = boto_session.client(
        service_name="sagemaker-runtime",
        endpoint_url=sagemaker_runtime_endpoint_url,
    )
    sagemaker_session = sagemaker.session.Session(
        boto_session=boto_session,
        sagemaker_client=sagemaker_client,
        sagemaker_runtime_client=sagemaker_runtime_client,
    )
    return sagemaker_session


def get_bedrock_runtime_client(
    boto_retry_mode: Literal["legacy", "standard", "adaptive"] = "adaptive",
    retry_attempts: int = 10,
) -> BedrockClient:
    """
    Get a Bedrock runtime client with a configurable retry policy (adaptive retry mode by default).

    :param boto_retry_mode: retry mode used for botocore config (legacy/standard/adaptive).
    :param retry_attempts: max retry attempts used for botocore client failures
    :return: The new Bedrock runtime client
    """
    # NOTE(review): the client created below is for the "bedrock-runtime" service, but the return
    # annotation is the stub type of the "bedrock" control-plane client (mypy_boto3_bedrock). The
    # matching stub is presumably BedrockRuntimeClient from mypy_boto3_bedrock_runtime; confirm
    # and switch once that stub package is an available dependency.
    boto_session = get_boto_session(boto_retry_mode, retry_attempts)
    bedrock_runtime_client = boto_session.client(service_name="bedrock-runtime")
    return bedrock_runtime_client


def is_endpoint_in_service(
    sagemaker_session: sagemaker.session.Session,
    endpoint_name: str,
) -> bool:
    """
    Check whether a SageMaker endpoint reports the "InService" status.

    Errors raised by describe_endpoint are not caught here and propagate to the caller.

    :param sagemaker_session: SageMaker session to be reused.
    :param endpoint_name: SageMaker endpoint name.
    :return: Whether the endpoint is in service
    """
    in_service = True
    desc = sagemaker_session.sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
    if not desc or "EndpointStatus" not in desc or desc["EndpointStatus"] != "InService":
        in_service = False
    return in_service


def is_text_embedding_js_model(jumpstart_model_id: str) -> bool:
    """
    Check whether a JumpStart model id is in the list of JumpStart text embedding models.

    :param jumpstart_model_id: JumpStart model id.
    :return: Whether the provided model id is a text embedding model.
    """
    # The model list is re-queried on every call; nothing is cached here.
    text_embedding_models = list_jumpstart_models("search_keywords includes Text Embedding")
    return jumpstart_model_id in text_embedding_models


def is_proprietary_js_model(region: str, jumpstart_model_id: str) -> bool:
    """
    Check whether a JumpStart model id appears in the proprietary-model SDK manifest.

    The manifest is downloaded from the JumpStart bucket of ``region``. The bucket base URL is a
    format string taken from the environment variable named by
    JUMPSTART_BUCKET_BASE_URL_FORMAT_ENV_VAR when set, otherwise JUMPSTART_BUCKET_BASE_URL_FORMAT,
    and is filled with the region twice.

    :param region: Region of the JumpStart bucket.
    :param jumpstart_model_id: JumpStart model id.
    :return: Whether the provided model id is a proprietary model.
    """
    jumpstart_bucket_base_url = os.environ.get(
        JUMPSTART_BUCKET_BASE_URL_FORMAT_ENV_VAR, JUMPSTART_BUCKET_BASE_URL_FORMAT
    ).format(region, region)
    proprietary_url = "{}/{}".format(jumpstart_bucket_base_url, PROPRIETARY_SDK_MANIFEST_FILE)

    # NOTE(review): no timeout is passed to urlopen, so a stalled connection may block indefinitely
    # unless a global socket default timeout is configured.
    with request.urlopen(proprietary_url) as f:
        proprietary_models_manifest = f.read().decode("utf-8")

    # Each manifest entry must support .get(); presumably a JSON list of objects keyed by MODEL_ID.
    model = seq(json.loads(proprietary_models_manifest)).find(lambda x: x.get(MODEL_ID, None) == jumpstart_model_id)

    return model is not None
def get_user_agent_extra() -> str:
    """Return a string containing various user-agent headers to be passed to a botocore config.

    This string will always contain SageMaker Python SDK headers obtained using the
    get_user_agent_extra_suffix utility function from sagemaker.user_agent. Unless the environment
    variable named by DISABLE_FMEVAL_TELEMETRY is set to a non-empty value, this string will
    additionally contain an fmeval-specific header of the form "lib/fmeval#<package version>".

    :return: A string to be used as the user_agent_extra parameter in a botocore config.
    """
    # Obtain user-agent headers for information such as SageMaker notebook instance type and SageMaker Studio app type.
    # We manually obtain these headers, so we can pass them in the user_agent_extra parameter of botocore.config.Config.
    # We can't rely on sagemaker.session.Session's initializer to fill in these headers for us, since we want to pass
    # our own sagemaker_client and sagemaker_runtime_client when creating a sagemaker.session.Session object.
    # When you pass these to the initializer, the python SDK code for constructing a botocore config with the SDK
    # headers won't get run.
    sagemaker_python_sdk_headers = get_user_agent_extra_suffix()
    # NOTE(review): os.getenv returns the raw string, so ANY non-empty value (including "false" or "0")
    # disables the fmeval header. Confirm this is the intended opt-out semantics.
    return (
        sagemaker_python_sdk_headers
        if os.getenv(DISABLE_FMEVAL_TELEMETRY)
        else f"{sagemaker_python_sdk_headers} lib/fmeval#{get_fmeval_package_version()}"
    )
Return a string containing various user-agent headers to be passed to a botocore config.
This string will always contain SageMaker Python SDK headers obtained using the get_user_agent_extra_suffix utility function from sagemaker.user_agent. If fmeval telemetry is enabled, this string will additionally contain an fmeval-specific header.
Returns
A string to be used as the user_agent_extra parameter in a botocore config.
def get_boto_session(
    boto_retry_mode: Literal["legacy", "standard", "adaptive"],
    retry_attempts: int,
) -> boto3.session.Session:
    """
    Get a boto3 session whose clients use the given retry mode and maximum retry attempts.

    The retry settings and the user-agent string from get_user_agent_extra() are installed as the
    default client config of a fresh botocore session, which then backs the returned boto3 session.

    :param boto_retry_mode: retry mode used for botocore config (legacy/standard/adaptive).
    :param retry_attempts: max retry attempts used for botocore client failures
    :return: The new session
    """
    botocore_session: botocore.session.Session = botocore.session.get_session()
    botocore_session.set_default_client_config(
        botocore.config.Config(
            # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html
            retries={"mode": boto_retry_mode, "max_attempts": retry_attempts},
            # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
            user_agent_extra=get_user_agent_extra(),
        )
    )
    return boto3.session.Session(botocore_session=botocore_session)
Get a boto3 session whose clients use the given retry mode and maximum retry attempts.
Returns
The new session
def get_sagemaker_session(
    boto_retry_mode: Literal["legacy", "standard", "adaptive"] = "adaptive",
    retry_attempts: int = 10,
) -> sagemaker.Session:
    """
    Build a SageMaker session on top of a retry-configured boto3 session.

    The SageMaker service ("sagemaker") and runtime ("sagemaker-runtime") clients are created
    explicitly from the boto3 session. This lets optional endpoint URL overrides apply; the
    overrides are read from the environment variables named by SAGEMAKER_SERVICE_ENDPOINT_URL
    and SAGEMAKER_RUNTIME_ENDPOINT_URL.

    :param boto_retry_mode: retry mode used for botocore config (legacy/standard/adaptive);
        defaults to "adaptive".
    :param retry_attempts: max retry attempts used for botocore client failures; defaults to 10.
    :return: The new session
    """
    retrying_boto_session = get_boto_session(boto_retry_mode, retry_attempts)

    # An unset variable yields None, and endpoint_url=None keeps boto3's default endpoint resolution.
    service_client = retrying_boto_session.client(
        service_name="sagemaker",
        endpoint_url=os.getenv(SAGEMAKER_SERVICE_ENDPOINT_URL),
    )
    runtime_client = retrying_boto_session.client(
        service_name="sagemaker-runtime",
        endpoint_url=os.getenv(SAGEMAKER_RUNTIME_ENDPOINT_URL),
    )

    return sagemaker.session.Session(
        boto_session=retrying_boto_session,
        sagemaker_client=service_client,
        sagemaker_runtime_client=runtime_client,
    )
Get a SageMaker session with a configurable retry policy (adaptive retry mode by default).
Parameters
- boto_retry_mode: retry mode used for botocore config (legacy/standard/adaptive).
- retry_attempts: max retry attempts used for botocore client failures
Returns
The new session
def get_bedrock_runtime_client(
    boto_retry_mode: Literal["legacy", "standard", "adaptive"] = "adaptive",
    retry_attempts: int = 10,
) -> BedrockClient:
    """
    Get a Bedrock runtime client with a configurable retry policy (adaptive retry mode by default).

    :param boto_retry_mode: retry mode used for botocore config (legacy/standard/adaptive).
    :param retry_attempts: max retry attempts used for botocore client failures
    :return: The new Bedrock runtime client
    """
    # NOTE(review): the client created below is for the "bedrock-runtime" service, but the return
    # annotation is the stub type of the "bedrock" control-plane client (mypy_boto3_bedrock). The
    # matching stub is presumably BedrockRuntimeClient from mypy_boto3_bedrock_runtime; confirm
    # and switch once that stub package is an available dependency.
    boto_session = get_boto_session(boto_retry_mode, retry_attempts)
    bedrock_runtime_client = boto_session.client(service_name="bedrock-runtime")
    return bedrock_runtime_client
Get a Bedrock runtime client with a configurable retry policy (adaptive retry mode by default).
Parameters
- boto_retry_mode: retry mode used for botocore config (legacy/standard/adaptive).
- retry_attempts: max retry attempts used for botocore client failures
Returns
The new Bedrock runtime client
def is_endpoint_in_service(
    sagemaker_session: sagemaker.session.Session,
    endpoint_name: str,
) -> bool:
    """
    Report whether a SageMaker endpoint currently has the "InService" status.

    Errors raised by describe_endpoint are not caught here and propagate to the caller.

    :param sagemaker_session: SageMaker session to be reused.
    :param endpoint_name: SageMaker endpoint name.
    :return: True only when the describe_endpoint response is non-empty and its
        "EndpointStatus" equals "InService"; False otherwise.
    """
    description = sagemaker_session.sagemaker_client.describe_endpoint(EndpointName=endpoint_name)

    # An empty response, or one without a status field, counts as "not in service".
    if not description:
        return False
    if "EndpointStatus" not in description:
        return False
    if description["EndpointStatus"] != "InService":
        return False
    return True
Parameters
- sagemaker_session: SageMaker session to be reused.
- endpoint_name: SageMaker endpoint name.
Returns
Whether the endpoint is in service
def is_text_embedding_js_model(jumpstart_model_id: str) -> bool:
    """
    Decide whether a JumpStart model id belongs to a text embedding model.

    :param jumpstart_model_id: JumpStart model id.
    :return: True if the id is among the JumpStart models matched by the filter
        "search_keywords includes Text Embedding", False otherwise.
    """
    # Filter expression understood by list_jumpstart_models; the listing is re-queried on every call.
    embedding_filter = "search_keywords includes Text Embedding"
    embedding_model_ids = list_jumpstart_models(embedding_filter)
    return jumpstart_model_id in embedding_model_ids
Parameters
- jumpstart_model_id: JumpStart model id.
Returns
Whether the provided model id is a text embedding model.
def is_proprietary_js_model(region: str, jumpstart_model_id: str) -> bool:
    """
    Check whether a JumpStart model id is listed in the proprietary-model SDK manifest.

    The manifest is downloaded from the JumpStart bucket of ``region``. The bucket base URL is a
    format string taken from the environment variable named by
    JUMPSTART_BUCKET_BASE_URL_FORMAT_ENV_VAR when set, otherwise JUMPSTART_BUCKET_BASE_URL_FORMAT,
    and is filled with the region twice.

    :param region: Region of the JumpStart bucket.
    :param jumpstart_model_id: JumpStart model id.
    :return: True if some manifest entry's MODEL_ID field equals ``jumpstart_model_id``.
    """
    url_template = os.environ.get(JUMPSTART_BUCKET_BASE_URL_FORMAT_ENV_VAR, JUMPSTART_BUCKET_BASE_URL_FORMAT)
    bucket_base_url = url_template.format(region, region)
    manifest_url = "{}/{}".format(bucket_base_url, PROPRIETARY_SDK_MANIFEST_FILE)

    # NOTE(review): no timeout is passed to urlopen, so a stalled connection may block indefinitely
    # unless a global socket default timeout is configured.
    with request.urlopen(manifest_url) as response:
        manifest_text = response.read().decode("utf-8")

    # Each entry must support .get(); presumably a JSON list of objects keyed by MODEL_ID.
    manifest_entries = json.loads(manifest_text)
    return any(entry.get(MODEL_ID, None) == jumpstart_model_id for entry in manifest_entries)
Parameters
- region: Region of the JumpStart bucket.
- jumpstart_model_id: JumpStart model id.
Returns
Whether the provided model id is a proprietary model.