Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 70 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# FastAPI Injector

A powerful dependency injection integration for FastAPI and Taskiq applications, built on top of the Python [injector](https://github.com/alecthomas/injector) library.
A powerful dependency injection integration for FastAPI and Taskiq applications, now featuring a self-contained dependency injection core designed for robustness and flexibility.

## Features

Expand All @@ -24,10 +24,16 @@ pip install git+https://github.com/10XScale-in/fastapi-injector.git
```python
from fastapi import FastAPI, Body
from fastapi_injector import attach_injector, Injected
from injector import Injector
from custom_injector.core import Injector # Updated import
from custom_injector.scopes import SingletonScope # Added for example
from typing import Annotated
import abc # Added for UserRepository example

# Define your interfaces and implementations
# Example User class (add for context if needed)
class User:
def __init__(self, name: str):
self.name = name
class UserRepository(abc.ABC):
@abc.abstractmethod
async def save_user(self, user: User) -> None:
Expand All @@ -41,7 +47,8 @@ class PostgresUserRepository(UserRepository):
# Create and configure your FastAPI application
app = FastAPI()
injector = Injector()
injector.binder.bind(UserRepository, to=PostgresUserRepository)
# injector.binder.bind(UserRepository, to=PostgresUserRepository) # Old way
injector.bind(UserRepository, to_class=PostgresUserRepository, scope=SingletonScope) # New way
attach_injector(app, injector)

# Use injection in your routes
Expand All @@ -57,16 +64,38 @@ async def create_user(
### Taskiq Integration

```python
from taskiq import TaskiqState, Context
from taskiq import TaskiqState, Context # Assuming TaskiqBroker is defined elsewhere
from fastapi_injector import attach_injector_taskiq, InjectedTaskiq
from custom_injector.core import Injector # Updated import
from custom_injector.scopes import SingletonScope # Added for example
import abc # Added for UserRepository example

# Example User class (add for context if needed)
class User:
def __init__(self, name: str):
self.name = name

# Define your interfaces and implementations (assuming from FastAPI example)
class UserRepository(abc.ABC):
@abc.abstractmethod
async def save_user(self, user: User) -> None:
pass

class PostgresUserRepository(UserRepository):
async def save_user(self, user: User) -> None:
# Implementation details
print(f"Saving user {user.name} to Postgres")
pass


# Initialize Taskiq broker and state
broker = TaskiqBroker()
# broker = TaskiqBroker() # Assuming broker is defined
state = TaskiqState()

# Configure injection
injector = Injector()
injector.binder.bind(UserRepository, to=PostgresUserRepository)
# injector.binder.bind(UserRepository, to=PostgresUserRepository) # Old way
injector.bind(UserRepository, to_class=PostgresUserRepository, scope=SingletonScope) # New way
attach_injector_taskiq(state, injector)

# Use injection in your tasks
Expand All @@ -84,15 +113,38 @@ async def process_user(
Enable request-scoped dependencies for better resource management:

```python
from fastapi_injector import InjectorMiddleware, request_scope, RequestScopeOptions
from fastapi_injector import InjectorMiddleware, RequestScope, RequestScopeOptions # Updated import
from custom_injector.core import Injector # Ensure Injector is imported if used in a standalone example
# Assuming app and injector are already defined as in Quick Start
# from custom_injector.scopes import SingletonScope # Not needed if RequestScope is the focus

# Example Connection classes (add for context)
class DatabaseConnection:
def query(self, sql: str):
print(f"Executing query: {sql} with {self}")
return "some_data"

class PostgresConnection(DatabaseConnection):
def __init__(self):
print(f"PostgresConnection {id(self)} created")
# Add __enter__ and __exit__ if enable_cleanup=True has effect
def __enter__(self):
print(f"PostgresConnection {id(self)} entered")
return self
def __exit__(self, exc_type, exc_val, exc_tb):
print(f"PostgresConnection {id(self)} exited")


# Configure request scope
# app = FastAPI() # Assuming app is defined
# injector = Injector() # Assuming injector is defined
options = RequestScopeOptions(enable_cleanup=True)
app.add_middleware(InjectorMiddleware, injector=injector)
attach_injector(app, injector, options)
app.add_middleware(InjectorMiddleware, injector=injector) # This should be before attach_injector if RequestScope itself is bound by attach_injector
attach_injector(app, injector, options) # attach_injector also binds RequestScope, RequestScopeFactory and RequestScopeOptions

# Bind with request scope
injector.binder.bind(DatabaseConnection, to=PostgresConnection, scope=request_scope)
# injector.binder.bind(DatabaseConnection, to=PostgresConnection, scope=request_scope) # Old way
injector.bind(DatabaseConnection, to_class=PostgresConnection, scope=RequestScope) # New way
```

## Synchronous Dependencies
Expand Down Expand Up @@ -125,8 +177,15 @@ from fastapi.testclient import TestClient

@pytest.fixture
def test_app():
from custom_injector.core import Injector # Updated import
from custom_injector.scopes import SingletonScope # Added for example
# Assuming UserRepository and MockUserRepository are defined
# class UserRepository(abc.ABC): ...
# class MockUserRepository(UserRepository): ...

injector = Injector()
injector.binder.bind(UserRepository, to=MockUserRepository)
# injector.binder.bind(UserRepository, to=MockUserRepository) # Old way
injector.bind(UserRepository, to_class=MockUserRepository, scope=SingletonScope) # New way
app = FastAPI()
attach_injector(app, injector)
return app
Expand Down
Empty file added custom_injector/__init__.py
Empty file.
97 changes: 97 additions & 0 deletions custom_injector/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from typing import Any, Type, Callable, Dict, TypeVar, Optional, Set, TYPE_CHECKING
from .scopes import Scope, TransientScope, SingletonScope
from .providers import Provider, ClassProvider, FactoryProvider, ValueProvider
from .exceptions import BindingError, ResolutionError, CircularDependencyError

if TYPE_CHECKING:
pass # No specific imports needed here for core.py itself for now

T = TypeVar("T")

class Binding:
def __init__(self, key: Type[T], provider: Provider, scope: Scope):
self.key = key
self.provider = provider
self.scope = scope

class Injector:
def __init__(self):
self._bindings: Dict[Type[Any], Binding] = {}
self._instances: Dict[Type[Any], Any] = {} # Cache for scoped instances, primarily singletons
self._currently_resolving: Set[Type[Any]] = set() # For circular dependency detection
self._scopes_instances: Dict[Type[Scope], Scope] = {} # Cache for scope instances

def _get_scope_instance(self, scope_class: Type[Scope]) -> Scope:
if scope_class not in self._scopes_instances:
try:
# Try to initialize with the injector instance itself
scope_instance = scope_class(self)
except TypeError:
# If scope_class.__init__ doesn't take an argument (or not an injector)
try:
scope_instance = scope_class()
except Exception as e:
raise BindingError(f"Could not instantiate scope {scope_class.__name__}: {e}")
self._scopes_instances[scope_class] = scope_instance
return self._scopes_instances[scope_class]

def bind(
self,
key: Type[T],
*,
to_class: Optional[Type[Any]] = None,
to_factory: Optional[Callable[..., Any]] = None,
to_value: Optional[Any] = None,
scope: Type[Scope] = TransientScope # Pass the class, not an instance
):
if [to_class, to_factory, to_value].count(None) < 2:
raise BindingError(f"Provide only one of to_class, to_factory, or to_value for key {key}")

if to_class:
provider = ClassProvider(to_class)
elif to_factory:
provider = FactoryProvider(to_factory)
elif to_value is not None: # Check for `is not None` because to_value could be False or 0
provider = ValueProvider(to_value)
# Values are inherently singletons in behavior, binding to a specific scope class doesn't change the value.
# We can enforce that values are always effectively singleton by wrapping them in a specific scope if needed,
# or just let the ValueProvider return the value. For now, scope applies like others.
else:
# Default to binding the key to itself if it's a class
if isinstance(key, type):
provider = ClassProvider(key)
else:
raise BindingError(f"Cannot determine provider for key {key}. Please specify to_class, to_factory, or to_value.")

self._bindings[key] = Binding(key, provider, self._get_scope_instance(scope))

def get(self, key: Type[T]) -> T:
if key in self._currently_resolving:
raise CircularDependencyError(f"Circular dependency detected for key {key}")
self._currently_resolving.add(key)

try:
binding = self._bindings.get(key)
if not binding:
# Attempt to auto-bind if 'key' is a class and not yet bound
if isinstance(key, type):
self.bind(key, to_class=key, scope=TransientScope) # Default to Transient for auto-bindings
binding = self._bindings.get(key)
if not binding: # Should not happen after auto-bind
raise ResolutionError(f"Auto-binding failed for key {key}")
else:
raise ResolutionError(f"No binding found for key {key} and it's not a class type for auto-binding.")

# The scope's get_instance will use the provider.
# The scope's get_instance will use the provider.
# The SingletonScope needs to be more robust to use the injector's _instances cache.
# Let's refine how SingletonScope interacts with the injector cache.

# Replace the existing scope handling block with this:
return binding.scope.get_instance(
binding_key=key,
provider_callable=lambda: binding.provider.get_instance(self),
injector_cache=self._instances
)
finally:
self._currently_resolving.remove(key)
15 changes: 15 additions & 0 deletions custom_injector/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
class DIError(Exception):
"""Base class for all dependency injection errors."""
pass

class BindingError(DIError):
"""Error during binding registration."""
pass

class ResolutionError(DIError):
"""Error during dependency resolution."""
pass

class CircularDependencyError(ResolutionError):
"""Circular dependency detected."""
pass
129 changes: 129 additions & 0 deletions custom_injector/providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import inspect
from abc import ABC, abstractmethod
from typing import Any, Type, Callable, TYPE_CHECKING, Dict, List, Tuple

if TYPE_CHECKING:
from .core import Injector
# from .core import T # Import T if used for generic type hints in providers
# T is not directly used in this file after the change, but Injector might need it.
# For now, keeping it commented out unless a direct need arises in this file.

class Provider(ABC):
@abstractmethod
def get_instance(self, injector: 'Injector') -> Any:
pass

class ValueProvider(Provider):
def __init__(self, value: Any):
self._value = value

def get_instance(self, injector: 'Injector') -> Any:
return self._value

def _resolve_dependencies(
injector: 'Injector',
callable_to_inspect: Callable[..., Any],
is_method: bool = False
) -> Tuple[List[Any], Dict[str, Any]]:
"""
Helper function to inspect a callable (function or method) and resolve its dependencies.
Returns a tuple of (args, kwargs) for calling the callable.
"""
args_to_pass = []
kwargs_to_pass = {}

# Need to import ResolutionError if we decide to raise it.
# from .exceptions import ResolutionError

sig = inspect.signature(callable_to_inspect)
params = list(sig.parameters.values())

if is_method: # Skip 'self' or 'cls' for methods
# Ensure there are parameters to skip, and that the first is 'self' or 'cls' by convention.
# A more robust check might be needed for atypical method signatures.
if params and (params[0].name == 'self' or params[0].name == 'cls'):
params = params[1:]
elif inspect.isclass(callable_to_inspect): # Check if it's a class constructor itself (e.g. __init__ of a metaclass)
# This path is less common for typical DI scenarios with __init__
# but added for robustness if callable_to_inspect is a class.
# However, __init__ is usually an instance method.
pass


for param in params:
param_type = param.annotation
if param_type is inspect.Parameter.empty:
# If there's a default, Python will use it if no value is provided.
# We only inject if there's a type hint.
if param.default is inspect.Parameter.empty:
# This is a required parameter without a type hint. DI cannot fill it.
# Python will raise a TypeError if it's not provided.
# Optionally, raise ResolutionError here:
# raise ResolutionError(f"Cannot inject parameter '{param.name}' for '{callable_to_inspect.__name__}': missing type annotation and no default value.")
pass # Let Python handle it, or raise error.
continue


# Resolve the dependency using the injector
# This assumes injector.get() can handle param_type.
try:
resolved_dependency = injector.get(param_type)
except Exception as e: # Catching a broad exception to wrap it, consider more specific ones from injector
# from .exceptions import ResolutionError # Ensure this is imported
# raise ResolutionError(f"Failed to resolve dependency for parameter '{param.name}' of type {param_type} in '{callable_to_inspect.__name__}': {e}")
# For now, re-raise to see original error from injector.get
raise

if param.kind == inspect.Parameter.POSITIONAL_ONLY:
args_to_pass.append(resolved_dependency)
elif param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
# If a default value exists and DI resolves a value, DI takes precedence.
# This is standard behavior: explicit injection overrides defaults.
args_to_pass.append(resolved_dependency)
elif param.kind == inspect.Parameter.KEYWORD_ONLY:
kwargs_to_pass[param.name] = resolved_dependency
elif param.kind == inspect.Parameter.VAR_POSITIONAL:
# print(f"Warning: VAR_POSITIONAL parameter (*{param.name}) in {callable_to_inspect.__name__} is not supported for DI.")
pass
elif param.kind == inspect.Parameter.VAR_KEYWORD:
# print(f"Warning: VAR_KEYWORD parameter (**{param.name}) in {callable_to_inspect.__name__} is not supported for DI.")
pass

return args_to_pass, kwargs_to_pass


class ClassProvider(Provider):
def __init__(self, cls: Type[Any]):
self._cls = cls

def get_instance(self, injector: 'Injector') -> Any:
constructor = self._cls.__init__

# Check if the constructor is the default object.__init__ which takes no arguments (other than self)
# or if it's a custom __init__ method.
if constructor is object.__init__:
# No custom __init__, so no dependencies to inject for constructor
return self._cls()

args, kwargs = _resolve_dependencies(injector, constructor, is_method=True)
return self._cls(*args, **kwargs)

class FactoryProvider(Provider):
def __init__(self, factory: Callable[..., Any]):
self._factory = factory

def get_instance(self, injector: 'Injector') -> Any:
# Check if factory is a method (bound or unbound) to correctly adjust for 'self'/'cls'
is_method = inspect.ismethod(self._factory) or \
(inspect.isfunction(self._factory) and '.' in self._factory.__qualname__ and not inspect.isclass(self._factory))

# A more robust check for methods, especially for staticmethods or classmethods if they were passed directly
# For simplicity, assuming typical functions or instance methods.
# If self._factory is a bound method, 'self' is already part of its context.
# inspect.signature() handles bound methods correctly, so is_method=False might be okay
# if the 'self' is already bound. However, if it's an unbound method taken from a class,
# then is_method=True would be needed if it were called like Class.method().
# Let's stick to the provided logic for now and refine if test cases show issues.

args, kwargs = _resolve_dependencies(injector, self._factory, is_method=is_method)
return self._factory(*args, **kwargs)
Loading