Skip to content

Commit a8a1150

Browse files
committed
Update TurnDetector plugins to use the new Stream API
1 parent 7f4c41d commit a8a1150

File tree

6 files changed

+125
-119
lines changed

6 files changed

+125
-119
lines changed
Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,12 @@
1-
from .turn_detection import (
2-
TurnEvent,
3-
TurnDetector,
4-
)
5-
from .events import (
6-
TurnStartedEvent,
7-
TurnEndedEvent,
8-
)
9-
1+
from .events import TurnEndedEvent, TurnStartedEvent
2+
from .turn_detection import TurnDetector, TurnEnded, TurnStarted
103

114
__all__ = [
125
# Base classes and types
13-
"TurnEvent",
146
"TurnDetector",
157
# Events
168
"TurnStartedEvent",
179
"TurnEndedEvent",
10+
"TurnEnded",
11+
"TurnStarted",
1812
]

agents-core/vision_agents/core/turn_detection/turn_detection.py

Lines changed: 34 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,36 @@
1-
from typing import Optional
2-
from abc import ABC, abstractmethod
3-
from enum import Enum
41
import uuid
2+
from abc import ABC, abstractmethod
3+
from dataclasses import dataclass
4+
from typing import TYPE_CHECKING, Optional
5+
56
from getstream.video.rtc.track_util import PcmData
67
from vision_agents.core.events.manager import EventManager
7-
from . import events
8-
from .events import TurnStartedEvent, TurnEndedEvent
9-
from ..agents.conversation import Conversation
8+
109
from ..edge.types import Participant
10+
from ..utils.stream import Stream
11+
from . import events
12+
13+
if TYPE_CHECKING:
14+
from vision_agents.core.agents.conversation import Conversation
1115

1216

13-
class TurnEvent(Enum):
14-
"""Events that can occur during turn detection (deprecated - use TurnStartedEvent/TurnEndedEvent)."""
17+
@dataclass
18+
class TurnStarted:
19+
"""
20+
Event emitted when a speaker starts their turn.
21+
"""
1522

16-
TURN_STARTED = "turn_started"
17-
TURN_ENDED = "turn_ended"
23+
participant: Participant
24+
confidence: float
25+
26+
27+
@dataclass
28+
class TurnEnded:
29+
participant: Participant
30+
confidence: float
31+
eager: bool = False
32+
trailing_silence_ms: Optional[float] = None
33+
duration_ms: Optional[float] = None
1834

1935

2036
class TurnDetector(ABC):
@@ -29,44 +45,24 @@ def __init__(
2945
self.provider_name = provider_name or self.__class__.__name__
3046
self.events = EventManager()
3147
self.events.register_events_from_module(events, ignore_not_compatible=True)
48+
self._output: Stream[TurnEnded | TurnStarted] = Stream()
3249

33-
def _emit_start_turn_event(self, event: TurnStartedEvent) -> None:
34-
event.session_id = self.session_id
35-
event.plugin_name = self.provider_name
36-
self.events.send(event)
37-
38-
def _emit_end_turn_event(
39-
self,
40-
participant: Participant,
41-
confidence: Optional[float] = None,
42-
trailing_silence_ms: Optional[float] = None,
43-
duration_ms: Optional[float] = None,
44-
eager_end_of_turn: bool = False,
45-
) -> None:
46-
if confidence is None:
47-
confidence = 0.5
48-
event = TurnEndedEvent(
49-
session_id=self.session_id,
50-
plugin_name=self.provider_name,
51-
participant=participant,
52-
confidence=confidence,
53-
trailing_silence_ms=trailing_silence_ms,
54-
duration_ms=duration_ms,
55-
eager_end_of_turn=eager_end_of_turn,
56-
)
57-
self.events.send(event)
50+
@property
51+
def output(self) -> Stream[TurnEnded | TurnStarted]:
52+
"""Pipeline output stream: consumers iterate, subclasses push via send_nowait."""
53+
return self._output
5854

5955
@abstractmethod
6056
async def process_audio(
6157
self,
62-
audio_data: PcmData,
58+
data: PcmData,
6359
participant: Participant,
64-
conversation: Optional[Conversation],
60+
conversation: "Conversation | None" = None,
6561
) -> None:
6662
"""Process the audio and trigger turn start or turn end events
6763
6864
Args:
69-
audio_data: PcmData object containing audio samples from Stream
65+
data: PcmData object containing audio samples from Stream
7066
participant: Participant that's speaking, includes user data
7167
conversation: Transcription/ chat history, sometimes useful for turn detection
7268
"""

plugins/smart_turn/tests/test_smart_turn.py

Lines changed: 41 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44
from vision_agents.core.agents.conversation import InMemoryConversation
55
from vision_agents.core.edge.types import Participant
6-
from vision_agents.core.turn_detection import TurnEndedEvent, TurnStartedEvent
6+
from vision_agents.core.turn_detection import TurnEnded, TurnStarted
77
from vision_agents.core.vad.silero import SileroVADSessionPool
88
from vision_agents.plugins.smart_turn.smart_turn_detection import SmartTurnDetection
99

@@ -34,26 +34,21 @@ async def test_turn_detection_chunks(self, smart_turn, mia_audio_16khz):
3434
participant = Participant(user_id="mia", id="mia", original={})
3535
conversation = InMemoryConversation(instructions="be nice", messages=[])
3636

37-
event_order = []
38-
39-
# Subscribe to events
40-
@smart_turn.events.subscribe
41-
async def on_start(event: TurnStartedEvent):
42-
logger.info(f"Smart turn turn started on {event.session_id}")
43-
event_order.append("start")
44-
45-
@smart_turn.events.subscribe
46-
async def on_stop(event: TurnEndedEvent):
47-
logger.info(f"Smart turn turn ended on {event.session_id}")
48-
event_order.append("stop")
49-
5037
for pcm in mia_audio_16khz.chunks(chunk_size=304):
5138
await smart_turn.process_audio(pcm, participant, conversation)
5239

53-
# Wait for background processing to complete
5440
await smart_turn.wait_for_processing_complete()
5541

56-
assert event_order == ["start", "stop"] or event_order == [
42+
items = await smart_turn.output.collect(timeout=1.0)
43+
kinds = [
44+
"start"
45+
if isinstance(item, TurnStarted)
46+
else "stop"
47+
if isinstance(item, TurnEnded)
48+
else None
49+
for item in items
50+
]
51+
assert kinds == ["start", "stop"] or kinds == [
5752
"start",
5853
"stop",
5954
"start",
@@ -63,39 +58,51 @@ async def on_stop(event: TurnEndedEvent):
6358
async def test_turn_detection(self, smart_turn, mia_audio_16khz):
6459
participant = Participant(user_id="mia", id="mia", original={})
6560
conversation = InMemoryConversation(instructions="be nice", messages=[])
66-
event_order = []
67-
68-
# Subscribe to events
69-
@smart_turn.events.subscribe
70-
async def on_start(event: TurnStartedEvent):
71-
logger.info(f"Smart turn turn started on {event.session_id}")
72-
event_order.append("start")
73-
74-
@smart_turn.events.subscribe
75-
async def on_stop(event: TurnEndedEvent):
76-
logger.info(f"Smart turn turn ended on {event.session_id}")
77-
event_order.append("stop")
7861

7962
await smart_turn.process_audio(mia_audio_16khz, participant, conversation)
8063

81-
# Wait for background processing to complete
8264
await smart_turn.wait_for_processing_complete()
8365

84-
# Verify that turn detection is working - we should get at least some turn events
66+
items = await smart_turn.output.collect(timeout=1.0)
67+
kinds = [
68+
"start"
69+
if isinstance(item, TurnStarted)
70+
else "stop"
71+
if isinstance(item, TurnEnded)
72+
else None
73+
for item in items
74+
]
8575
# With continuous processing, we may get multiple start/stop cycles
86-
assert event_order == ["start", "stop"] or event_order == [
76+
assert kinds == ["start", "stop"] or kinds == [
8777
"start",
8878
"stop",
8979
"start",
9080
"stop",
9181
]
9282

83+
async def test_silence_does_not_start_segment(self, smart_turn, silence_1s_16khz):
84+
participant = Participant(user_id="mia", id="mia", original={})
85+
conversation = InMemoryConversation(instructions="be nice", messages=[])
86+
87+
await smart_turn.process_audio(silence_1s_16khz, participant, conversation)
88+
await smart_turn.wait_for_processing_complete()
89+
90+
items = await smart_turn.output.collect(timeout=0.5)
91+
assert items == []
92+
93+
async def test_speech_starts_segment(self, smart_turn, mia_audio_16khz):
94+
participant = Participant(user_id="mia", id="mia", original={})
95+
conversation = InMemoryConversation(instructions="be nice", messages=[])
96+
97+
await smart_turn.process_audio(mia_audio_16khz, participant, conversation)
98+
await smart_turn.wait_for_processing_complete()
99+
100+
items = await smart_turn.output.collect(timeout=1.0)
101+
assert any(isinstance(item, TurnStarted) for item in items)
102+
93103
"""
94104
TODO
95105
- Test that the 2nd turn detect includes the audio from the first turn
96106
- Test that turn detection is ran after 8s of audio
97107
- Test that turn detection is run after speech and 2s of silence
98-
- Test that silence doens't start a new segmetn
99-
- Test that speaking starts a new segment
100-
101108
"""

plugins/smart_turn/vision_agents/plugins/smart_turn/smart_turn_detection.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
from vision_agents.core.edge.types import Participant
1515
from vision_agents.core.turn_detection import (
1616
TurnDetector,
17-
TurnStartedEvent,
17+
TurnEnded,
18+
TurnStarted,
1819
)
1920
from vision_agents.core.utils.utils import ensure_model
2021
from vision_agents.core.vad.silero import SileroVADSession, SileroVADSessionPool
@@ -162,7 +163,7 @@ async def process_audio(
162163
self,
163164
audio_data: PcmData,
164165
participant: Participant,
165-
conversation: Optional[Conversation],
166+
conversation: Conversation | None = None,
166167
) -> None:
167168
"""
168169
Fast, non-blocking audio packet enqueueing.
@@ -289,11 +290,13 @@ async def _process_audio_packet(
289290
prediction = await self._predict_turn_completed(merged, participant)
290291
turn_ended = prediction > 0.5
291292
if turn_ended:
292-
self._emit_end_turn_event(
293-
participant=participant,
294-
confidence=prediction,
295-
trailing_silence_ms=trailing_silence_ms,
296-
duration_ms=self._active_segment.duration_ms,
293+
await self.output.send(
294+
TurnEnded(
295+
participant=participant,
296+
confidence=prediction,
297+
trailing_silence_ms=trailing_silence_ms,
298+
duration_ms=self._active_segment.duration_ms,
299+
)
297300
)
298301
self._active_segment = None
299302
self._silence = Silence()
@@ -304,7 +307,12 @@ async def _process_audio_packet(
304307
self._pre_speech_buffer.append(merged)
305308
self._pre_speech_buffer = self._pre_speech_buffer.tail(8)
306309
elif is_speech and self._active_segment is None:
307-
self._emit_start_turn_event(TurnStartedEvent(participant=participant))
310+
await self.output.send(
311+
TurnStarted(
312+
participant=participant,
313+
confidence=speech_probability,
314+
)
315+
)
308316
# create a new segment
309317
self._active_segment = PcmData(
310318
sample_rate=RATE, channels=1, format=AudioFormat.F32

plugins/vogent/tests/test_vogent.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
import asyncio
21
import logging
32

43
import pytest
54
from vision_agents.core.agents.conversation import InMemoryConversation
65
from vision_agents.core.edge.types import Participant
7-
from vision_agents.core.turn_detection import TurnEndedEvent, TurnStartedEvent
6+
from vision_agents.core.turn_detection import TurnEnded, TurnStarted
87
from vision_agents.plugins.vogent.vogent_turn_detection import VogentTurnDetection
98

109
logger = logging.getLogger(__name__)
@@ -30,18 +29,6 @@ async def test_turn_detection(
3029
):
3130
participant = Participant(user_id="mia", original={}, id="mia")
3231
conversation = InMemoryConversation(instructions="be nice", messages=[])
33-
event_order = []
34-
35-
# Subscribe to events
36-
@vogent_turn_detection.events.subscribe
37-
async def on_start(event: TurnStartedEvent):
38-
logger.info(f"Vogent turn started on {event.session_id}")
39-
event_order.append("start")
40-
41-
@vogent_turn_detection.events.subscribe
42-
async def on_stop(event: TurnEndedEvent):
43-
logger.info(f"Vogent turn ended on {event.session_id}")
44-
event_order.append("stop")
4532

4633
await vogent_turn_detection.process_audio(
4734
mia_audio_16khz, participant, conversation
@@ -50,12 +37,18 @@ async def on_stop(event: TurnEndedEvent):
5037
silence_2s_48khz, participant, conversation
5138
)
5239

53-
await asyncio.sleep(0.001)
40+
await vogent_turn_detection.wait_for_processing_complete()
5441

55-
await asyncio.sleep(5)
56-
57-
# Verify that turn detection is working - we should get at least some turn events
58-
assert event_order == ["start", "stop"] or event_order == [
42+
items = await vogent_turn_detection.output.collect(timeout=1.0)
43+
kinds = [
44+
"start"
45+
if isinstance(item, TurnStarted)
46+
else "stop"
47+
if isinstance(item, TurnEnded)
48+
else None
49+
for item in items
50+
]
51+
assert kinds == ["start", "stop"] or kinds == [
5952
"start",
6053
"stop",
6154
"start",

0 commit comments

Comments
 (0)