fmeval.model_runners.sm_jumpstart_model_runner
Module to manage model runners for SageMaker Endpoints with JumpStart LLMs.
1""" 2Module to manage model runners for SageMaker Endpoints with JumpStart LLMs. 3""" 4import logging 5 6import sagemaker 7from typing import Optional, Tuple, List, Union 8 9from sagemaker.jumpstart.enums import JumpStartModelType 10 11import fmeval.util as util 12from fmeval.constants import MIME_TYPE_JSON 13from fmeval.exceptions import EvalAlgorithmClientError 14from fmeval.model_runners.model_runner import ModelRunner 15from fmeval.model_runners.util import ( 16 get_sagemaker_session, 17 is_endpoint_in_service, 18 is_proprietary_js_model, 19 is_text_embedding_js_model, 20) 21 22logger = logging.getLogger(__name__) 23 24 25class JumpStartModelRunner(ModelRunner): 26 """ 27 A class to manage the creation and deletion of a SageMaker Jumpstart model runner 28 when the user provides an endpoint name corresponding to a SageMaker Endpoint 29 for a JumpStart LLM. 30 """ 31 32 def __init__( 33 self, 34 endpoint_name: str, 35 model_id: str, 36 content_template: Optional[str] = None, 37 model_version: Optional[str] = "*", 38 custom_attributes: Optional[str] = None, 39 output: Optional[str] = None, 40 log_probability: Optional[str] = None, 41 embedding: Optional[str] = None, 42 component_name: Optional[str] = None, 43 ): 44 """ 45 :param endpoint_name: Name of the SageMaker endpoint to be used for model predictions 46 :param model_id: Identifier of the SageMaker Jumpstart model 47 :param content_template: String template to compose the model input from the prompt 48 :param model_version: Version of the SageMaker Jumpstart model 49 :param custom_attributes: String that contains the custom attributes to be passed to 50 SageMaker endpoint invocation 51 :param output: JMESPath expression of output in the model output 52 :param log_probability: JMESPath expression of log probability in the model output 53 :param embedding: JMESPath expression of embedding in the model output 54 :param component_name: Name of the Amazon SageMaker inference component corresponding 55 the 
predictor 56 """ 57 sagemaker_session = get_sagemaker_session() 58 util.require( 59 is_endpoint_in_service(sagemaker_session, endpoint_name), 60 f"Endpoint {endpoint_name} is not in service", 61 ) 62 # Default model type is always OPEN_WEIGHTS. See https://tinyurl.com/yc58s6wj 63 jumpstart_model_type = JumpStartModelType.OPEN_WEIGHTS 64 if is_proprietary_js_model(sagemaker_session.boto_region_name, model_id): 65 jumpstart_model_type = JumpStartModelType.PROPRIETARY 66 is_text_embedding_model = is_text_embedding_js_model(model_id) 67 68 super().__init__( 69 content_template=content_template, 70 output=output, 71 log_probability=log_probability, 72 embedding=embedding, 73 content_type=MIME_TYPE_JSON, 74 accept_type=MIME_TYPE_JSON, 75 jumpstart_model_id=model_id, 76 jumpstart_model_version=model_version, 77 jumpstart_model_type=jumpstart_model_type, 78 is_embedding_model=is_text_embedding_model, 79 ) 80 self._endpoint_name = endpoint_name 81 self._model_id = model_id 82 self._content_template = content_template 83 self._model_version = model_version 84 self._custom_attributes = custom_attributes 85 self._output = output 86 self._log_probability = log_probability 87 self._embedding = embedding 88 self._component_name = component_name 89 self._is_embedding_model = is_text_embedding_model 90 91 predictor = sagemaker.predictor.retrieve_default( 92 endpoint_name=self._endpoint_name, 93 model_id=self._model_id, 94 model_type=jumpstart_model_type, 95 model_version=self._model_version, 96 sagemaker_session=sagemaker_session, 97 ) 98 util.require(predictor.accept == MIME_TYPE_JSON, f"Model accept type `{predictor.accept}` is not supported.") 99 self._predictor = predictor 100 101 def predict(self, prompt: str) -> Union[Tuple[Optional[str], Optional[float]], List[float]]: 102 """ 103 Invoke the SageMaker endpoint and parse the model response. 104 :param prompt: Input data for which you want the model to provide inference. 
105 """ 106 composed_data = self._composer.compose(prompt) 107 model_output = self._predictor.predict( 108 data=composed_data, 109 custom_attributes=self._custom_attributes, 110 component_name=self._component_name, 111 ) 112 # expect embedding from all text embedding models, return directly 113 if self._is_embedding_model: 114 embedding = self._extractor.extract_embedding(data=model_output, num_records=1) 115 return embedding 116 # expect output from all model responses in JS 117 output = self._extractor.extract_output(data=model_output, num_records=1) 118 log_probability = None 119 try: 120 log_probability = self._extractor.extract_log_probability(data=model_output, num_records=1) 121 except EvalAlgorithmClientError as e: 122 # log_probability may be missing 123 logger.warning(f"Unable to fetch log_probability from model response: {e}") 124 return output, log_probability 125 126 def __reduce__(self): 127 """ 128 Custom serializer method used by Ray when it serializes instances of this 129 class in eval_algorithms.util.generate_model_predict_response_for_dataset. 130 """ 131 serialized_data = ( 132 self._endpoint_name, 133 self._model_id, 134 self._content_template, 135 self._model_version, 136 self._custom_attributes, 137 self._output, 138 self._log_probability, 139 self._embedding, 140 self._component_name, 141 ) 142 return self.__class__, serialized_data
logger = logging.getLogger(__name__)  # evaluates to: <Logger fmeval.model_runners.sm_jumpstart_model_runner (INFO)>
class JumpStartModelRunner(ModelRunner):
    """
    Model runner for a SageMaker Endpoint that hosts a JumpStart LLM.

    Given the name of an in-service endpoint, this runner resolves the default
    JumpStart predictor for the model and uses it for inference.
    """

    def __init__(
        self,
        endpoint_name: str,
        model_id: str,
        content_template: Optional[str] = None,
        model_version: Optional[str] = "*",
        custom_attributes: Optional[str] = None,
        output: Optional[str] = None,
        log_probability: Optional[str] = None,
        embedding: Optional[str] = None,
        component_name: Optional[str] = None,
    ):
        """
        :param endpoint_name: Name of the SageMaker endpoint to be used for model predictions
        :param model_id: Identifier of the SageMaker Jumpstart model
        :param content_template: String template to compose the model input from the prompt
        :param model_version: Version of the SageMaker Jumpstart model
        :param custom_attributes: String that contains the custom attributes to be passed to
            SageMaker endpoint invocation
        :param output: JMESPath expression of output in the model output
        :param log_probability: JMESPath expression of log probability in the model output
        :param embedding: JMESPath expression of embedding in the model output
        :param component_name: Name of the Amazon SageMaker inference component corresponding
            to the predictor
        """
        session = get_sagemaker_session()
        util.require(
            is_endpoint_in_service(session, endpoint_name),
            f"Endpoint {endpoint_name} is not in service",
        )
        # OPEN_WEIGHTS is the default model type; proprietary models are detected
        # per region + model id. See https://tinyurl.com/yc58s6wj
        model_type = (
            JumpStartModelType.PROPRIETARY
            if is_proprietary_js_model(session.boto_region_name, model_id)
            else JumpStartModelType.OPEN_WEIGHTS
        )
        embedding_model = is_text_embedding_js_model(model_id)

        super().__init__(
            content_template=content_template,
            output=output,
            log_probability=log_probability,
            embedding=embedding,
            content_type=MIME_TYPE_JSON,
            accept_type=MIME_TYPE_JSON,
            jumpstart_model_id=model_id,
            jumpstart_model_version=model_version,
            jumpstart_model_type=model_type,
            is_embedding_model=embedding_model,
        )
        # Retain constructor arguments so __reduce__ can rebuild this runner.
        self._endpoint_name = endpoint_name
        self._model_id = model_id
        self._content_template = content_template
        self._model_version = model_version
        self._custom_attributes = custom_attributes
        self._output = output
        self._log_probability = log_probability
        self._embedding = embedding
        self._component_name = component_name
        self._is_embedding_model = embedding_model

        predictor = sagemaker.predictor.retrieve_default(
            endpoint_name=self._endpoint_name,
            model_id=self._model_id,
            model_type=model_type,
            model_version=self._model_version,
            sagemaker_session=session,
        )
        # Only JSON responses are supported by this runner.
        util.require(predictor.accept == MIME_TYPE_JSON, f"Model accept type `{predictor.accept}` is not supported.")
        self._predictor = predictor

    def predict(self, prompt: str) -> Union[Tuple[Optional[str], Optional[float]], List[float]]:
        """
        Invoke the SageMaker endpoint and parse the model response.

        :param prompt: Input data for which you want the model to provide inference.
        """
        request_payload = self._composer.compose(prompt)
        response = self._predictor.predict(
            data=request_payload,
            custom_attributes=self._custom_attributes,
            component_name=self._component_name,
        )
        # Text embedding models always yield an embedding; hand it back as-is.
        if self._is_embedding_model:
            return self._extractor.extract_embedding(data=response, num_records=1)
        # Every JumpStart text model response is expected to contain an output.
        text_output = self._extractor.extract_output(data=response, num_records=1)
        log_prob: Optional[float] = None
        try:
            log_prob = self._extractor.extract_log_probability(data=response, num_records=1)
        except EvalAlgorithmClientError as e:
            # Not all models report log probabilities; treat absence as non-fatal.
            logger.warning(f"Unable to fetch log_probability from model response: {e}")
        return text_output, log_prob

    def __reduce__(self):
        """
        Custom serializer method used by Ray when it serializes instances of this
        class in eval_algorithms.util.generate_model_predict_response_for_dataset.
        """
        ctor_args = (
            self._endpoint_name,
            self._model_id,
            self._content_template,
            self._model_version,
            self._custom_attributes,
            self._output,
            self._log_probability,
            self._embedding,
            self._component_name,
        )
        return self.__class__, ctor_args
A class to manage the creation and deletion of a SageMaker Jumpstart model runner when the user provides an endpoint name corresponding to a SageMaker Endpoint for a JumpStart LLM.
JumpStartModelRunner( endpoint_name: str, model_id: str, content_template: Optional[str] = None, model_version: Optional[str] = '*', custom_attributes: Optional[str] = None, output: Optional[str] = None, log_probability: Optional[str] = None, embedding: Optional[str] = None, component_name: Optional[str] = None)
def __init__(
    self,
    endpoint_name: str,
    model_id: str,
    content_template: Optional[str] = None,
    model_version: Optional[str] = "*",
    custom_attributes: Optional[str] = None,
    output: Optional[str] = None,
    log_probability: Optional[str] = None,
    embedding: Optional[str] = None,
    component_name: Optional[str] = None,
):
    """
    :param endpoint_name: Name of the SageMaker endpoint to be used for model predictions
    :param model_id: Identifier of the SageMaker Jumpstart model
    :param content_template: String template to compose the model input from the prompt
    :param model_version: Version of the SageMaker Jumpstart model
    :param custom_attributes: String that contains the custom attributes to be passed to
        SageMaker endpoint invocation
    :param output: JMESPath expression of output in the model output
    :param log_probability: JMESPath expression of log probability in the model output
    :param embedding: JMESPath expression of embedding in the model output
    :param component_name: Name of the Amazon SageMaker inference component corresponding
        to the predictor
    """
    session = get_sagemaker_session()
    util.require(
        is_endpoint_in_service(session, endpoint_name),
        f"Endpoint {endpoint_name} is not in service",
    )
    # OPEN_WEIGHTS is the default model type; proprietary models are detected
    # per region + model id. See https://tinyurl.com/yc58s6wj
    model_type = (
        JumpStartModelType.PROPRIETARY
        if is_proprietary_js_model(session.boto_region_name, model_id)
        else JumpStartModelType.OPEN_WEIGHTS
    )
    embedding_model = is_text_embedding_js_model(model_id)

    super().__init__(
        content_template=content_template,
        output=output,
        log_probability=log_probability,
        embedding=embedding,
        content_type=MIME_TYPE_JSON,
        accept_type=MIME_TYPE_JSON,
        jumpstart_model_id=model_id,
        jumpstart_model_version=model_version,
        jumpstart_model_type=model_type,
        is_embedding_model=embedding_model,
    )
    # Retain constructor arguments so __reduce__ can rebuild this runner.
    self._endpoint_name = endpoint_name
    self._model_id = model_id
    self._content_template = content_template
    self._model_version = model_version
    self._custom_attributes = custom_attributes
    self._output = output
    self._log_probability = log_probability
    self._embedding = embedding
    self._component_name = component_name
    self._is_embedding_model = embedding_model

    predictor = sagemaker.predictor.retrieve_default(
        endpoint_name=self._endpoint_name,
        model_id=self._model_id,
        model_type=model_type,
        model_version=self._model_version,
        sagemaker_session=session,
    )
    # Only JSON responses are supported by this runner.
    util.require(predictor.accept == MIME_TYPE_JSON, f"Model accept type `{predictor.accept}` is not supported.")
    self._predictor = predictor
Parameters
- endpoint_name: Name of the SageMaker endpoint to be used for model predictions
- model_id: Identifier of the SageMaker Jumpstart model
- content_template: String template to compose the model input from the prompt
- model_version: Version of the SageMaker Jumpstart model
- custom_attributes: String that contains the custom attributes to be passed to SageMaker endpoint invocation
- output: JMESPath expression of output in the model output
- log_probability: JMESPath expression of log probability in the model output
- embedding: JMESPath expression of embedding in the model output
- component_name: Name of the Amazon SageMaker inference component corresponding to the predictor
def
predict( self, prompt: str) -> Union[Tuple[Optional[str], Optional[float]], List[float]]:
102 def predict(self, prompt: str) -> Union[Tuple[Optional[str], Optional[float]], List[float]]: 103 """ 104 Invoke the SageMaker endpoint and parse the model response. 105 :param prompt: Input data for which you want the model to provide inference. 106 """ 107 composed_data = self._composer.compose(prompt) 108 model_output = self._predictor.predict( 109 data=composed_data, 110 custom_attributes=self._custom_attributes, 111 component_name=self._component_name, 112 ) 113 # expect embedding from all text embedding models, return directly 114 if self._is_embedding_model: 115 embedding = self._extractor.extract_embedding(data=model_output, num_records=1) 116 return embedding 117 # expect output from all model responses in JS 118 output = self._extractor.extract_output(data=model_output, num_records=1) 119 log_probability = None 120 try: 121 log_probability = self._extractor.extract_log_probability(data=model_output, num_records=1) 122 except EvalAlgorithmClientError as e: 123 # log_probability may be missing 124 logger.warning(f"Unable to fetch log_probability from model response: {e}") 125 return output, log_probability
Invoke the SageMaker endpoint and parse the model response.
Parameters
- prompt: Input data for which you want the model to provide inference.