fmeval.model_runners.sm_jumpstart_model_runner

Module to manage model runners for SageMaker Endpoints with JumpStart LLMs.

  1"""
  2Module to manage model runners for SageMaker Endpoints with JumpStart LLMs.
  3"""
  4import logging
  5
  6import sagemaker
  7from typing import Optional, Tuple, List, Union
  8
  9from sagemaker.jumpstart.enums import JumpStartModelType
 10
 11import fmeval.util as util
 12from fmeval.constants import MIME_TYPE_JSON
 13from fmeval.exceptions import EvalAlgorithmClientError
 14from fmeval.model_runners.model_runner import ModelRunner
 15from fmeval.model_runners.util import (
 16    get_sagemaker_session,
 17    is_endpoint_in_service,
 18    is_proprietary_js_model,
 19    is_text_embedding_js_model,
 20)
 21
 22logger = logging.getLogger(__name__)
 23
 24
 25class JumpStartModelRunner(ModelRunner):
 26    """
 27    A class to manage the creation and deletion of a SageMaker Jumpstart model runner
 28    when the user provides an endpoint name corresponding to a SageMaker Endpoint
 29    for a JumpStart LLM.
 30    """
 31
 32    def __init__(
 33        self,
 34        endpoint_name: str,
 35        model_id: str,
 36        content_template: Optional[str] = None,
 37        model_version: Optional[str] = "*",
 38        custom_attributes: Optional[str] = None,
 39        output: Optional[str] = None,
 40        log_probability: Optional[str] = None,
 41        embedding: Optional[str] = None,
 42        component_name: Optional[str] = None,
 43    ):
 44        """
 45        :param endpoint_name: Name of the SageMaker endpoint to be used for model predictions
 46        :param model_id: Identifier of the SageMaker Jumpstart model
 47        :param content_template: String template to compose the model input from the prompt
 48        :param model_version: Version of the SageMaker Jumpstart model
 49        :param custom_attributes: String that contains the custom attributes to be passed to
 50                                  SageMaker endpoint invocation
 51        :param output: JMESPath expression of output in the model output
 52        :param log_probability: JMESPath expression of log probability in the model output
 53        :param embedding: JMESPath expression of embedding in the model output
 54        :param component_name: Name of the Amazon SageMaker inference component corresponding
 55                            the predictor
 56        """
 57        sagemaker_session = get_sagemaker_session()
 58        util.require(
 59            is_endpoint_in_service(sagemaker_session, endpoint_name),
 60            f"Endpoint {endpoint_name} is not in service",
 61        )
 62        # Default model type is always OPEN_WEIGHTS. See https://tinyurl.com/yc58s6wj
 63        jumpstart_model_type = JumpStartModelType.OPEN_WEIGHTS
 64        if is_proprietary_js_model(sagemaker_session.boto_region_name, model_id):
 65            jumpstart_model_type = JumpStartModelType.PROPRIETARY
 66        is_text_embedding_model = is_text_embedding_js_model(model_id)
 67
 68        super().__init__(
 69            content_template=content_template,
 70            output=output,
 71            log_probability=log_probability,
 72            embedding=embedding,
 73            content_type=MIME_TYPE_JSON,
 74            accept_type=MIME_TYPE_JSON,
 75            jumpstart_model_id=model_id,
 76            jumpstart_model_version=model_version,
 77            jumpstart_model_type=jumpstart_model_type,
 78            is_embedding_model=is_text_embedding_model,
 79        )
 80        self._endpoint_name = endpoint_name
 81        self._model_id = model_id
 82        self._content_template = content_template
 83        self._model_version = model_version
 84        self._custom_attributes = custom_attributes
 85        self._output = output
 86        self._log_probability = log_probability
 87        self._embedding = embedding
 88        self._component_name = component_name
 89        self._is_embedding_model = is_text_embedding_model
 90
 91        predictor = sagemaker.predictor.retrieve_default(
 92            endpoint_name=self._endpoint_name,
 93            model_id=self._model_id,
 94            model_type=jumpstart_model_type,
 95            model_version=self._model_version,
 96            sagemaker_session=sagemaker_session,
 97        )
 98        util.require(predictor.accept == MIME_TYPE_JSON, f"Model accept type `{predictor.accept}` is not supported.")
 99        self._predictor = predictor
100
101    def predict(self, prompt: str) -> Union[Tuple[Optional[str], Optional[float]], List[float]]:
102        """
103        Invoke the SageMaker endpoint and parse the model response.
104        :param prompt: Input data for which you want the model to provide inference.
105        """
106        composed_data = self._composer.compose(prompt)
107        model_output = self._predictor.predict(
108            data=composed_data,
109            custom_attributes=self._custom_attributes,
110            component_name=self._component_name,
111        )
112        # expect embedding from all text embedding models, return directly
113        if self._is_embedding_model:
114            embedding = self._extractor.extract_embedding(data=model_output, num_records=1)
115            return embedding
116        # expect output from all model responses in JS
117        output = self._extractor.extract_output(data=model_output, num_records=1)
118        log_probability = None
119        try:
120            log_probability = self._extractor.extract_log_probability(data=model_output, num_records=1)
121        except EvalAlgorithmClientError as e:
122            # log_probability may be missing
123            logger.warning(f"Unable to fetch log_probability from model response: {e}")
124        return output, log_probability
125
126    def __reduce__(self):
127        """
128        Custom serializer method used by Ray when it serializes instances of this
129        class in eval_algorithms.util.generate_model_predict_response_for_dataset.
130        """
131        serialized_data = (
132            self._endpoint_name,
133            self._model_id,
134            self._content_template,
135            self._model_version,
136            self._custom_attributes,
137            self._output,
138            self._log_probability,
139            self._embedding,
140            self._component_name,
141        )
142        return self.__class__, serialized_data
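Because __reduce__ returns the class together with its constructor arguments, unpickling a runner re-runs __init__, which re-validates the endpoint and retrieves a fresh predictor in the deserializing process; this is how Ray ships runner instances to worker processes. A minimal sketch of that round trip, assuming `runner` is an already-constructed JumpStartModelRunner and the process has live access to the endpoint:

    import pickle

    # __reduce__ returns (cls, init_args), so unpickling calls
    # JumpStartModelRunner(*init_args) and recreates the SageMaker predictor.
    clone = pickle.loads(pickle.dumps(runner))
    assert clone is not runner  # a new instance backed by the same endpoint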
class JumpStartModelRunner(fmeval.model_runners.model_runner.ModelRunner):

A class to manage the creation and deletion of a SageMaker JumpStart model runner when the user provides an endpoint name corresponding to a SageMaker Endpoint for a JumpStart LLM.

JumpStartModelRunner(endpoint_name: str, model_id: str, content_template: Optional[str] = None, model_version: Optional[str] = '*', custom_attributes: Optional[str] = None, output: Optional[str] = None, log_probability: Optional[str] = None, embedding: Optional[str] = None, component_name: Optional[str] = None)
Parameters
  • endpoint_name: Name of the SageMaker endpoint to be used for model predictions
  • model_id: Identifier of the SageMaker JumpStart model
  • content_template: String template to compose the model input from the prompt
  • model_version: Version of the SageMaker JumpStart model
  • custom_attributes: String that contains the custom attributes to be passed to SageMaker endpoint invocation
  • output: JMESPath expression of output in the model output
  • log_probability: JMESPath expression of log probability in the model output
  • embedding: JMESPath expression of embedding in the model output
  • component_name: Name of the Amazon SageMaker inference component corresponding to the predictor
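
A minimal construction sketch. The endpoint name, JumpStart model ID, content template, and JMESPath expression below are placeholder assumptions for a text-generation endpoint whose JSON response looks like [{"generated_text": "..."}]; substitute the values that match your deployed model and its response schema.

    from fmeval.model_runners.sm_jumpstart_model_runner import JumpStartModelRunner

    # Hypothetical values: adjust to your own JumpStart deployment.
    runner = JumpStartModelRunner(
        endpoint_name="my-jumpstart-endpoint",        # an in-service endpoint
        model_id="huggingface-llm-falcon-7b-bf16",    # the JumpStart model ID
        content_template='{"inputs": $prompt, "parameters": {"max_new_tokens": 64}}',
        output="[0].generated_text",                  # where the text lives in the response
    )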
def predict(self, prompt: str) -> Union[Tuple[Optional[str], Optional[float]], List[float]]:

Invoke the SageMaker endpoint and parse the model response. For a text embedding model, returns the extracted embedding as a list of floats; otherwise, returns a tuple of the extracted output text and its log probability (None when the response does not include one).

Parameters
  • prompt: Input data for which you want the model to provide inference.
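
A short sketch of consuming the Union return type, assuming `runner` is the instance constructed above: text-generation runners yield an (output, log_probability) tuple, while text embedding runners yield a list of floats.

    result = runner.predict("London is the capital of")
    if isinstance(result, tuple):
        output, log_probability = result  # log_probability is None if the response lacks it
        print(f"output={output!r}, log_probability={log_probability}")
    else:
        print(f"embedding of dimension {len(result)}")  # text embedding model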