Skip to content

pydantic_ai.models.fallback

FallbackModel dataclass

Bases: Model

A model that uses one or more fallback models upon failure.

Apart from __init__, all methods are private or match those of the base class.

Source code in pydantic_ai_slim/pydantic_ai/models/fallback.py
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
@dataclass(init=False)
class FallbackModel(Model):
    """A model that uses one or more fallback models upon failure.

    Models are tried in the order they were passed to `__init__`; the first one
    that answers without raising wins. An exception only moves on to the next
    model when the `fallback_on` condition accepts it, otherwise it propagates
    immediately.

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    # Candidate models in the order they are tried; index 0 is the default model.
    models: list[Model]

    # Snapshot of `model_name` taken in `__init__`. It is excluded from the repr but
    # still takes part in the dataclass-generated `__eq__`, so it must always be set.
    _model_name: str = field(repr=False)
    # Predicate deciding whether an exception should trigger a fallback.
    _fallback_on: Callable[[Exception], bool]

    def __init__(
        self,
        default_model: Model | KnownModelName,
        *fallback_models: Model | KnownModelName,
        fallback_on: Callable[[Exception], bool] | tuple[type[Exception], ...] = (ModelHTTPError,),
    ):
        """Initialize a fallback model instance.

        Args:
            default_model: The name or instance of the default model to use.
            fallback_models: The names or instances of the fallback models to use upon failure.
            fallback_on: A callable or tuple of exceptions that should trigger a fallback.
                Defaults to `(ModelHTTPError,)`.
        """
        super().__init__()
        self.models = [infer_model(default_model), *[infer_model(m) for m in fallback_models]]

        # `_model_name` is a declared dataclass field; leaving it unassigned makes the
        # generated `__eq__` (and helpers such as `dataclasses.asdict`) raise
        # AttributeError. NOTE(review): assumes the base `Model` does not set it itself.
        self._model_name = self.model_name

        if isinstance(fallback_on, tuple):
            # A tuple of exception types is turned into a predicate by the factory helper.
            self._fallback_on = _default_fallback_condition_factory(fallback_on)
        else:
            self._fallback_on = fallback_on

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ModelResponse:
        """Try each model in sequence until one succeeds.

        In case of failure, raise a FallbackExceptionGroup with all exceptions.

        Raises:
            FallbackExceptionGroup: If every model failed with an exception accepted
                by the fallback condition.
            Exception: Any exception rejected by the fallback condition is re-raised
                immediately, without trying the remaining models.
        """
        exceptions: list[Exception] = []

        for model in self.models:
            # Each model gets its own customized copy of the request parameters.
            customized_model_request_parameters = model.customize_request_parameters(model_request_parameters)
            try:
                response = await model.request(messages, model_settings, customized_model_request_parameters)
            except Exception as exc:
                if self._fallback_on(exc):
                    exceptions.append(exc)
                    continue
                raise exc

            self._set_span_attributes(model)
            return response

        raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
        run_context: RunContext[Any] | None = None,
    ) -> AsyncIterator[StreamedResponse]:
        """Try each model in sequence until one succeeds.

        Only failures raised while *opening* a model's stream are considered for
        fallback; once a stream has been yielded, errors raised while it is being
        consumed propagate to the caller.

        Raises:
            FallbackExceptionGroup: If opening the stream failed for every model with
                an exception accepted by the fallback condition.
        """
        exceptions: list[Exception] = []

        for model in self.models:
            customized_model_request_parameters = model.customize_request_parameters(model_request_parameters)
            # The exit stack keeps the winning model's stream context open while the
            # response is yielded, and closes it when this context manager exits.
            async with AsyncExitStack() as stack:
                try:
                    response = await stack.enter_async_context(
                        model.request_stream(messages, model_settings, customized_model_request_parameters, run_context)
                    )
                except Exception as exc:
                    if self._fallback_on(exc):
                        exceptions.append(exc)
                        continue
                    raise exc  # pragma: no cover

                self._set_span_attributes(model)
                yield response
                return

        raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)

    def _set_span_attributes(self, model: Model):
        """Copy the attributes of the model that answered onto the current span.

        Best effort: any exception raised here is suppressed. The span is only
        updated when it is recording and its `gen_ai.request.model` attribute still
        equals this wrapper's `model_name`.
        """
        with suppress(Exception):
            span = get_current_span()
            if span.is_recording():
                attributes = getattr(span, 'attributes', {})
                if attributes.get('gen_ai.request.model') == self.model_name:  # pragma: no branch
                    span.set_attributes(InstrumentedModel.model_attributes(model))

    @property
    def model_name(self) -> str:
        """The model name: `fallback:` followed by the comma-joined names of all models."""
        return f'fallback:{",".join(model.model_name for model in self.models)}'

    @property
    def system(self) -> str:
        """The system: `fallback:` followed by the comma-joined systems of all models."""
        return f'fallback:{",".join(model.system for model in self.models)}'

    @property
    def base_url(self) -> str | None:
        """The base URL of the default (first) model."""
        return self.models[0].base_url

__init__

__init__(
    default_model: Model | KnownModelName,
    *fallback_models: Model | KnownModelName,
    fallback_on: (
        Callable[[Exception], bool]
        | tuple[type[Exception], ...]
    ) = (ModelHTTPError,)
)

Initialize a fallback model instance.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `default_model` | `Model \| KnownModelName` | The name or instance of the default model to use. | required |
| `fallback_models` | `Model \| KnownModelName` | The names or instances of the fallback models to use upon failure. | `()` |
| `fallback_on` | `Callable[[Exception], bool] \| tuple[type[Exception], ...]` | A callable or tuple of exceptions that should trigger a fallback. | `(ModelHTTPError,)` |
Source code in pydantic_ai_slim/pydantic_ai/models/fallback.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def __init__(
    self,
    default_model: Model | KnownModelName,
    *fallback_models: Model | KnownModelName,
    fallback_on: Callable[[Exception], bool] | tuple[type[Exception], ...] = (ModelHTTPError,),
):
    """Initialize a fallback model instance.

    Args:
        default_model: The name or instance of the default model to use.
        fallback_models: The names or instances of the fallback models to use upon failure.
        fallback_on: A callable or tuple of exceptions that should trigger a fallback.
            Defaults to `(ModelHTTPError,)`.
    """
    super().__init__()
    self.models = [infer_model(default_model), *[infer_model(m) for m in fallback_models]]

    # `_model_name` is a declared dataclass field; leaving it unassigned makes the
    # generated `__eq__` (and helpers such as `dataclasses.asdict`) raise
    # AttributeError. NOTE(review): assumes the base `Model` does not set it itself.
    self._model_name = self.model_name

    if isinstance(fallback_on, tuple):
        # A tuple of exception types is turned into a predicate by the factory helper.
        self._fallback_on = _default_fallback_condition_factory(fallback_on)
    else:
        self._fallback_on = fallback_on

request async

request(
    messages: list[ModelMessage],
    model_settings: ModelSettings | None,
    model_request_parameters: ModelRequestParameters,
) -> ModelResponse

Try each model in sequence until one succeeds.

In case of failure, raise a FallbackExceptionGroup with all exceptions.

Source code in pydantic_ai_slim/pydantic_ai/models/fallback.py
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
async def request(
    self,
    messages: list[ModelMessage],
    model_settings: ModelSettings | None,
    model_request_parameters: ModelRequestParameters,
) -> ModelResponse:
    """Return the response of the first model that answers without failing.

    Models are tried in order. A failure accepted by the fallback condition is
    recorded and the next model is tried; any other failure is re-raised at once.
    If every model fails, a FallbackExceptionGroup carrying all recorded
    exceptions is raised.
    """
    failures: list[Exception] = []

    for candidate in self.models:
        params = candidate.customize_request_parameters(model_request_parameters)
        try:
            result = await candidate.request(messages, model_settings, params)
        except Exception as error:
            # Not a fallback-worthy error: surface it without trying other models.
            if not self._fallback_on(error):
                raise error
            failures.append(error)
        else:
            self._set_span_attributes(candidate)
            return result

    raise FallbackExceptionGroup('All models from FallbackModel failed', failures)

request_stream async

request_stream(
    messages: list[ModelMessage],
    model_settings: ModelSettings | None,
    model_request_parameters: ModelRequestParameters,
    run_context: RunContext[Any] | None = None,
) -> AsyncIterator[StreamedResponse]

Try each model in sequence until one succeeds.

Source code in pydantic_ai_slim/pydantic_ai/models/fallback.py
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
@asynccontextmanager
async def request_stream(
    self,
    messages: list[ModelMessage],
    model_settings: ModelSettings | None,
    model_request_parameters: ModelRequestParameters,
    run_context: RunContext[Any] | None = None,
) -> AsyncIterator[StreamedResponse]:
    """Yield the streamed response of the first model whose stream opens successfully.

    Models are tried in order. A failure while opening a stream is recorded and
    the next model is tried when the fallback condition accepts it; otherwise it
    is re-raised. If no stream could be opened, a FallbackExceptionGroup carrying
    all recorded exceptions is raised.
    """
    failures: list[Exception] = []

    for candidate in self.models:
        params = candidate.customize_request_parameters(model_request_parameters)
        # The exit stack owns the candidate's stream context for as long as the
        # response is handed out to the caller.
        async with AsyncExitStack() as exit_stack:
            try:
                stream = await exit_stack.enter_async_context(
                    candidate.request_stream(messages, model_settings, params, run_context)
                )
            except Exception as error:
                if not self._fallback_on(error):
                    raise error  # pragma: no cover
                failures.append(error)
            else:
                self._set_span_attributes(candidate)
                yield stream
                return

    raise FallbackExceptionGroup('All models from FallbackModel failed', failures)

model_name property

model_name: str

The model name.