"""Minimal **LCEL-style** composition: ``Runnable`` steps wired with the pipe operator ``|``.
This is a tiny re-implementation of the idea behind LangChain Expression Language—**no**
``langchain`` dependency. Enough to express linear chains, ``assign`` on dict state, and
parallel fan-out that merges back in a later step.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping
from typing import Any
def _as_runnable(step: Runnable | Callable[[Any], Any]) -> Runnable:
    """Coerce *step* into a ``Runnable``.

    A ``Runnable`` is handed back unchanged; any other callable is wrapped in
    a ``RunnableLambda``.

    Raises:
        TypeError: If *step* is neither a ``Runnable`` nor callable.
    """
    already_runnable = isinstance(step, Runnable)
    if not already_runnable and not callable(step):
        raise TypeError("Expected a Runnable or a single-argument callable")
    return step if already_runnable else RunnableLambda(step)
class Runnable(ABC):
    """One step in a pipeline: call ``invoke``; use ``a | b`` to chain.

    Either operand of ``|`` may be a plain single-argument callable as long as
    the other operand is a ``Runnable``; the callable is wrapped in a
    ``RunnableLambda`` before the two are joined.
    """

    @abstractmethod
    def invoke(self, input: Any) -> Any:
        """Execute this step and return its output."""

    def __or__(self, other: object) -> Any:
        """Return ``self | other`` as a ``RunnableSequence``.

        Returns ``NotImplemented`` when *other* is neither a ``Runnable`` nor
        callable, so Python can try the reflected operator or raise
        ``TypeError`` itself.
        """
        if isinstance(other, Runnable):
            right = other
        elif callable(other):
            right = RunnableLambda(other)
        else:
            return NotImplemented
        return RunnableSequence(self, right)

    def __ror__(self, other: object) -> Any:
        """Return ``other | self`` when the left operand cannot handle ``|``.

        Python only falls back to this when ``other.__or__`` is missing or
        returned ``NotImplemented`` -- typically a plain function or lambda on
        the left-hand side of the pipe.
        """
        if isinstance(other, Runnable):
            left = other
        elif callable(other):
            left = RunnableLambda(other)
        else:
            return NotImplemented
        return RunnableSequence(left, self)
class RunnableSequence(Runnable):
    """Result of ``left | right``: run *left*, pass output into *right*.

    Further chaining with ``|`` uses the operator inherited from ``Runnable``,
    so ``a | b | c`` builds the left-nested sequence ``(a | b) | c``.
    """

    def __init__(self, left: Runnable, right: Runnable) -> None:
        self._left = left
        self._right = right

    def invoke(self, input: Any) -> Any:
        """Run the left step on *input*, then return the right step's result for that output."""
        intermediate = self._left.invoke(input)
        return self._right.invoke(intermediate)
class RunnableLambda(Runnable):
    """Wrap any single-argument callable as a ``Runnable``."""

    def __init__(self, func: Callable[[Any], Any]) -> None:
        """Store *func* for later invocation.

        Raises:
            TypeError: If *func* is not callable.
        """
        # Fail at construction time rather than on the first ``invoke`` so the
        # error points at the place where the bad step was wired in.
        if not callable(func):
            msg = "RunnableLambda expects a single-argument callable"
            raise TypeError(msg)
        self._func = func

    def invoke(self, input: Any) -> Any:
        """Call the wrapped function with *input* and return its result."""
        return self._func(input)
class RunnablePassthrough(Runnable):
    """Identity step. ``assign`` builds a step that augments mapping inputs."""

    def invoke(self, input: Any) -> Any:
        """Return *input* unchanged."""
        return input

    @classmethod
    def assign(cls, **kwargs: Callable[[Any], Any]) -> RunnableAssign:
        """Build a ``RunnableAssign`` step from keyword getters.

        Each keyword becomes a new key; values are ``fn(input)`` (input is the
        full mapping).
        """
        return RunnableAssign(kwargs)
class RunnableAssign(Runnable):
    """``{**input, k: getter(input)}`` for each configured ``k``.

    A getter may be a plain single-argument callable or a ``Runnable``; a
    ``Runnable`` getter is run through its ``invoke`` method.
    """

    def __init__(self, getters: Mapping[str, Runnable | Callable[[Any], Any]]) -> None:
        # Copy so later changes to the caller's mapping do not affect this step.
        self._getters = dict(getters)

    def invoke(self, input: Any) -> Any:
        """Return a new dict holding *input*'s items plus one computed value per getter.

        Every getter receives the original *input* mapping, not the partially
        built result. A computed key overrides an input key of the same name.

        Raises:
            TypeError: If *input* is not a mapping.
        """
        if not isinstance(input, Mapping):
            msg = "RunnableAssign expects a mapping input"
            raise TypeError(msg)
        out: dict[str, Any] = dict(input)
        for key, getter in self._getters.items():
            if isinstance(getter, Runnable):
                # Runnables define ``invoke`` rather than ``__call__``.
                out[key] = getter.invoke(input)
            else:
                out[key] = getter(input)
        return out
class RunnableParallel(Runnable):
    """Run several runnables on the *same* input; output is ``{branch_name: branch_output}``.

    Branches are invoked one after another, in the order they were passed as
    keyword arguments, and each receives the same input object.
    """

    def __init__(self, **branches: Runnable | Callable[[Any], Any]) -> None:
        # Normalise every branch up front so ``invoke`` can treat them uniformly.
        self._branches = {name: _as_runnable(step) for name, step in branches.items()}

    def invoke(self, input: Any) -> Any:
        """Invoke every branch with *input* and return the results keyed by branch name."""
        return {name: branch.invoke(input) for name, branch in self._branches.items()}