Azure Inference

AzureAIEmbeddingsModel

Examples

from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
from llama_index.embeddings.azure_inference import AzureAIEmbeddingsModel

embed_model = AzureAIEmbeddingsModel(
    endpoint="https://[your-endpoint].inference.ai.azure.com",
    credential="your-api-key",
)

# # If using Microsoft Entra ID authentication, you can create the
# # client as follows
#
# from azure.identity import DefaultAzureCredential
#
# embed_model = AzureAIEmbeddingsModel(
#     endpoint="https://[your-endpoint].inference.ai.azure.com",
#     credential=DefaultAzureCredential()
# )
#
# # If you plan to use asynchronous calling, make sure to use the async
# # credentials as follows
#
# from azure.identity.aio import DefaultAzureCredential as DefaultAzureCredentialAsync
#
# embed_model = AzureAIEmbeddingsModel(
#     endpoint="https://[your-endpoint].inference.ai.azure.com",
#     credential=DefaultAzureCredentialAsync()
# )

# Once the client is instantiated, you can set the context to use the model
Settings.embed_model = embed_model

documents = SimpleDirectoryReader("./data").load_data()
index = VectorStoreIndex.from_documents(documents)
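
Beyond the `Settings` integration, the model can be queried directly through the `BaseEmbedding` interface, without building an index. A minimal sketch, assuming the `embed_model` created above and a reachable endpoint:

# Embed a single text and inspect the vector dimensionality
vector = embed_model.get_text_embedding("Azure AI model inference")
print(len(vector))

# Query embeddings go through the same underlying endpoint
query_vector = embed_model.get_query_embedding("What is model inference?")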
Bases: BaseEmbedding
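
When `endpoint` and `credential` are omitted, the constructor falls back to the `AZURE_INFERENCE_ENDPOINT` and `AZURE_INFERENCE_CREDENTIAL` environment variables, and a plain string credential is wrapped in an `AzureKeyCredential` (see the source below). A minimal sketch, assuming an API-key credential:

import os

from llama_index.embeddings.azure_inference import AzureAIEmbeddingsModel

os.environ["AZURE_INFERENCE_ENDPOINT"] = "https://[your-endpoint].inference.ai.azure.com"
os.environ["AZURE_INFERENCE_CREDENTIAL"] = "your-api-key"

# Both parameters are now resolved from the environment
embed_model = AzureAIEmbeddingsModel()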
class AzureAIEmbeddingsModel(BaseEmbedding):
    """
    Azure AI model inference for embeddings.

    Examples:
        ```python
        from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
        from llama_index.embeddings.azure_inference import AzureAIEmbeddingsModel

        embed_model = AzureAIEmbeddingsModel(
            endpoint="https://[your-endpoint].inference.ai.azure.com",
            credential="your-api-key",
        )

        # # If using Microsoft Entra ID authentication, you can create the
        # # client as follows
        #
        # from azure.identity import DefaultAzureCredential
        #
        # embed_model = AzureAIEmbeddingsModel(
        #     endpoint="https://[your-endpoint].inference.ai.azure.com",
        #     credential=DefaultAzureCredential()
        # )
        #
        # # If you plan to use asynchronous calling, make sure to use the async
        # # credentials as follows
        #
        # from azure.identity.aio import DefaultAzureCredential as DefaultAzureCredentialAsync
        #
        # embed_model = AzureAIEmbeddingsModel(
        #     endpoint="https://[your-endpoint].inference.ai.azure.com",
        #     credential=DefaultAzureCredentialAsync()
        # )

        # Once the client is instantiated, you can set the context to use the model
        Settings.embed_model = embed_model

        documents = SimpleDirectoryReader("./data").load_data()
        index = VectorStoreIndex.from_documents(documents)
        ```

    """

    model_kwargs: Dict[str, Any] = Field(
        default_factory=dict,
        description="Additional keyword arguments passed to the model.",
    )

    _client: EmbeddingsClient = PrivateAttr()
    _async_client: EmbeddingsClientAsync = PrivateAttr()

    def __init__(
        self,
        endpoint: Optional[str] = None,
        credential: Optional[Union[str, AzureKeyCredential, "TokenCredential"]] = None,
        model_name: Optional[str] = None,
        api_version: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        num_workers: Optional[int] = None,
        client_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ):
        client_kwargs = client_kwargs or {}

        endpoint = get_from_param_or_env(
            "endpoint", endpoint, "AZURE_INFERENCE_ENDPOINT", None
        )
        credential = get_from_param_or_env(
            "credential", credential, "AZURE_INFERENCE_CREDENTIAL", None
        )
        credential = (
            AzureKeyCredential(credential)
            if isinstance(credential, str)
            else credential
        )

        if not endpoint:
            raise ValueError(
                "You must provide an endpoint to use the Azure AI model inference LLM. "
                "Pass the endpoint as a parameter or set the AZURE_INFERENCE_ENDPOINT "
                "environment variable."
            )

        if not credential:
            raise ValueError(
                "You must provide a credential to use the Azure AI model inference LLM. "
                "Pass the credential as a parameter or set the AZURE_INFERENCE_CREDENTIAL "
                "environment variable."
            )

        if api_version:
            client_kwargs["api_version"] = api_version

        client = EmbeddingsClient(
            endpoint=endpoint,
            credential=credential,
            user_agent="llamaindex",
            **client_kwargs,
        )

        async_client = EmbeddingsClientAsync(
            endpoint=endpoint,
            credential=credential,
            user_agent="llamaindex",
            **client_kwargs,
        )

        if not model_name:
            try:
                # Get model info from the endpoint. This method may not be supported by all
                # endpoints.
                model_info = client.get_model_info()
                model_name = model_info.get("model_name", None)
            except HttpResponseError:
                # self._client is not assigned until the end of __init__,
                # so report the endpoint from the local variable instead.
                logger.warning(
                    f"Endpoint '{endpoint}' does not support model metadata retrieval. "
                    "Unable to populate model attributes."
                )

        super().__init__(
            model_name=model_name or "unknown",
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            num_workers=num_workers,
            **kwargs,
        )

        self._client = client
        self._async_client = async_client

    @classmethod
    def class_name(cls) -> str:
        return "AzureAIEmbeddingsModel"

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        additional_kwargs = {}
        if self.model_name and self.model_name != "unknown":
            additional_kwargs["model"] = self.model_name
        if self.model_kwargs:
            # pass any extra model parameters
            additional_kwargs.update(self.model_kwargs)

        return additional_kwargs

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        return self._client.embed(input=[query], **self._model_kwargs).data[0].embedding

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """The asynchronous version of _get_query_embedding."""
        return (
            (await self._async_client.embed(input=[query], **self._model_kwargs))
            .data[0]
            .embedding
        )

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        return self._client.embed(input=[text], **self._model_kwargs).data[0].embedding

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        return (
            (await self._async_client.embed(input=[text], **self._model_kwargs))
            .data[0]
            .embedding
        )

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings."""
        embedding_response = self._client.embed(input=texts, **self._model_kwargs).data
        return [embed.embedding for embed in embedding_response]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings."""
        embedding_response = await self._async_client.embed(
            input=texts, **self._model_kwargs
        )
        return [embed.embedding for embed in embedding_response.data]
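
The async variants above delegate to `EmbeddingsClientAsync`, so they should be paired with an async credential. A minimal sketch of batch embedding, assuming the same endpoint; `aget_text_embedding_batch` is inherited from `BaseEmbedding` and fans out over the `_aget_text_embeddings` method shown above:

import asyncio

from azure.identity.aio import DefaultAzureCredential as DefaultAzureCredentialAsync
from llama_index.embeddings.azure_inference import AzureAIEmbeddingsModel


async def main() -> None:
    embed_model = AzureAIEmbeddingsModel(
        endpoint="https://[your-endpoint].inference.ai.azure.com",
        credential=DefaultAzureCredentialAsync(),
    )
    # Inherited batch helper; internally calls _aget_text_embeddings
    vectors = await embed_model.aget_text_embedding_batch(
        ["first document", "second document"]
    )
    print(len(vectors), len(vectors[0]))


asyncio.run(main())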