"""@private"""
import re
from typing import Any, Dict, List, Literal, Optional
# NOTE ON DEPENDENCIES:
# - since Jan 2024, there is https://pypi.org/project/langchain-openai/ which is a separate package and imports openai models.
# Decided to not make this a dependency of langfuse as few people will have this. Need to match these models manually
# - langchain_community is loaded as a dependency of langchain, but extremely unreliable. Decided to not depend on it.
def _extract_model_name(
    serialized: Dict[str, Any],
    **kwargs: Any,
):
    """Extract the model name from the serialized or kwargs object.

    This is used to get the model names for Langfuse. Several strategies are
    tried in a fixed order and the first truthy result is returned, so the
    order of operations is important:

    1. Known classes whose model name is stored at a known path.
    2. ``AzureOpenAI``, whose name is built from the deployment name/version.
    3. Known classes whose model name is parsed out of the ``repr`` string.
    4. A catch-all over the most likely paths.

    Args:
        serialized: Serialized model object. The last element of its ``"id"``
            entry is compared against the known class names.
        **kwargs: Additional keyword arguments, searched by path
            (e.g. ``invocation_params`` -> ``model_name``).

    Returns:
        The extracted model name, or ``None`` if no strategy matched.
    """
    # First, extract known models where we know the model name aka id.
    # Each entry is (class name, path (array) to the model name, object to read from).
    models_by_id = [
        ("ChatGoogleGenerativeAI", ["kwargs", "model"], "serialized"),
        ("ChatMistralAI", ["kwargs", "model"], "serialized"),
        ("ChatVertexAi", ["kwargs", "model_name"], "serialized"),
        ("ChatVertexAI", ["kwargs", "model_name"], "serialized"),
        ("OpenAI", ["invocation_params", "model_name"], "kwargs"),
        ("ChatOpenAI", ["invocation_params", "model_name"], "kwargs"),
        ("AzureChatOpenAI", ["invocation_params", "model"], "kwargs"),
        ("AzureChatOpenAI", ["invocation_params", "model_name"], "kwargs"),
        ("HuggingFacePipeline", ["invocation_params", "model_id"], "kwargs"),
        ("BedrockChat", ["kwargs", "model_id"], "serialized"),
        ("Bedrock", ["kwargs", "model_id"], "serialized"),
        ("ChatBedrock", ["kwargs", "model_id"], "serialized"),
        ("LlamaCpp", ["invocation_params", "model_path"], "kwargs"),
    ]
    for model_name, keys, select_from in models_by_id:
        model = _extract_model_by_path_for_id(
            model_name, serialized, kwargs, keys, select_from
        )
        if model:
            return model

    # Second, we match AzureOpenAI as we need to extract the model name,
    # deployment version and deployment name.
    class_path = serialized.get("id") or []
    if class_path and class_path[-1] == "AzureOpenAI":
        return _extract_azure_openai_model(serialized, kwargs)

    # Third, for some models we are unable to extract the model by a path in
    # an object. We are given a string representation (repr) of the model
    # object and use a regex to extract the model from that repr string.
    # Each entry is (class name, attribute name in the repr, default value).
    models_by_pattern = [
        ("Anthropic", "model", "anthropic"),
        ("ChatAnthropic", "model", None),
        ("ChatTongyi", "model_name", None),
        ("ChatCohere", "model", None),
        ("Cohere", "model", None),
        ("HuggingFaceHub", "model", None),
        ("ChatAnyscale", "model_name", None),
        ("TextGen", "model", "text-gen"),
        ("Ollama", "model", None),
        ("ChatOllama", "model", None),
        ("ChatFireworks", "model", None),
        ("ChatPerplexity", "model", None),
        ("VLLM", "model", None),
        ("Xinference", "model_uid", None),
    ]
    for model_name, pattern, default in models_by_pattern:
        model = _extract_model_from_repr_by_pattern(
            model_name, serialized, pattern, default
        )
        if model:
            return model

    # Finally, we try to extract the most likely paths as a catch all.
    random_paths = [
        ["kwargs", "model_name"],
        ["kwargs", "model"],
        ["invocation_params", "model_name"],
        ["invocation_params", "model"],
    ]
    for select in ["kwargs", "serialized"]:
        for path in random_paths:
            model = _extract_model_by_path(serialized, kwargs, path, select)
            if model:
                return model

    return None


def _extract_azure_openai_model(serialized: Dict[str, Any], kwargs: Dict[str, Any]):
    """Build the model name for a serialized ``AzureOpenAI`` object.

    Prefers ``invocation_params["model_name"]`` from ``kwargs``. Otherwise the
    name is ``"<deployment_name>-<deployment_version>"`` read from
    ``serialized["kwargs"]``, or just the deployment name when no version is
    available. Returns ``None`` when no usable deployment name is present.
    """
    invocation_params = kwargs.get("invocation_params") or {}
    model_name = invocation_params.get("model_name")
    if model_name:
        return model_name

    serialized_kwargs = serialized.get("kwargs") or {}

    deployment_name = serialized_kwargs.get("deployment_name")
    if not deployment_name or not isinstance(deployment_name, str):
        return None

    # The version is only read when an API version is configured; this keeps
    # the condition and keys of the original code, which mistakenly stored the
    # result in ``deployment_name`` and then concatenated ``None``.
    # NOTE(review): confirm that "deployment_version" is the intended key.
    deployment_version = None
    if serialized_kwargs.get("openai_api_version"):
        deployment_version = serialized_kwargs.get("deployment_version")

    if deployment_version:
        return f"{deployment_name}-{deployment_version}"
    return deployment_name
def _extract_model_from_repr_by_pattern(
    id: str, serialized: dict, pattern: str, default: Optional[str] = None
):
    """Extract a model name from the ``repr`` string of a specific class.

    Args:
        id: Class name to match against the last element of
            ``serialized["id"]``. (The name shadows the builtin but is kept
            for interface compatibility.)
        serialized: Serialized model object; ``"id"`` and ``"repr"`` are read.
        pattern: Attribute name to look for in the repr (``<pattern>='...'``).
        default: Value returned when the class matches and a repr exists but
            the attribute cannot be found in it.

    Returns:
        The extracted value, else ``default`` if truthy, else ``None``.
        ``None`` is also returned when the class does not match, when
        ``"id"`` is missing or empty, or when there is no ``repr``.
    """
    # Guard against a missing/empty "id" instead of failing on ``None[-1]``.
    class_path = serialized.get("id") if serialized else None
    if not class_path or class_path[-1] != id:
        return None

    representation = serialized.get("repr")
    if not representation:
        return None

    extracted = _extract_model_with_regex(pattern, representation)
    return extracted or default or None
def _extract_model_with_regex(pattern: str, text: str):
    """Return the single-quoted value assigned to ``pattern`` inside ``text``.

    Searches ``text`` for the first occurrence of ``<pattern>='<value>'`` and
    returns ``<value>``; the value is matched non-greedily, so it ends at the
    next single quote. ``pattern`` is interpolated into the regular
    expression as-is. Returns ``None`` when there is no such occurrence.
    """
    found = re.search(rf"{pattern}='(.*?)'", text)
    return found.group(1) if found is not None else None
def _extract_model_by_path_for_id(
    id: str,
    serialized: dict,
    kwargs: dict,
    keys: List[str],
    select_from: Literal["serialized", "kwargs"] = "serialized",
):
    """Extract a model name by path, but only for one specific class.

    Args:
        id: Class name to match against the last element of
            ``serialized["id"]``. (The name shadows the builtin but is kept
            for interface compatibility.)
        serialized: Serialized model object.
        kwargs: Keyword-argument dict that can be searched instead.
        keys: Sequence of keys to follow from the selected root object.
        select_from: ``"kwargs"`` to start from ``kwargs``; any other value
            starts from ``serialized``. The default keeps the previous
            effective behaviour, where the default value never equalled
            ``"kwargs"``.

    Returns:
        The value found at ``keys`` when the class matches; ``None`` when the
        class does not match or ``"id"`` is missing or empty.
    """
    # Guard against a missing/empty "id" instead of failing on ``None[-1]``.
    class_path = serialized.get("id") if serialized else None
    if not class_path or class_path[-1] != id:
        return None

    return _extract_model_by_path(serialized, kwargs, keys, select_from)
def _extract_model_by_path(
    serialized: dict,
    kwargs: dict,
    keys: List[str],
    select_from: Literal["serialized", "kwargs"] = "serialized",
):
    """Follow ``keys`` through nested dicts and return the value found.

    Args:
        serialized: Root object used unless ``select_from`` is ``"kwargs"``.
        kwargs: Root object used when ``select_from`` is ``"kwargs"``.
        keys: Sequence of keys to follow, outermost first.
        select_from: Which root object to start from. The default keeps the
            previous effective behaviour, where the default value never
            equalled ``"kwargs"`` and so ``serialized`` was used.

    Returns:
        The value at the end of the path, or ``None`` when any step is
        missing, falsy, or not a dict that can be descended into.
    """
    current = kwargs if select_from == "kwargs" else serialized

    for key in keys:
        # A missing root or a non-dict intermediate value (e.g. a string)
        # cannot be descended into; treat it as "not found" instead of raising.
        if not isinstance(current, dict):
            return None
        current = current.get(key)
        if not current:
            return None

    return current or None