stream.go
199 lines1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
package assistant
import (
"bufio"
"io"
"strings"
"sync"
)
// EventType identifies the kind of a StreamEvent.
type EventType int

// Streaming event types used by StreamReader to classify incoming SSE events.
// Values are assigned sequentially starting at zero.
const (
	// EventContentDelta carries a chunk of text content.
	EventContentDelta EventType = iota
	// EventToolCallStart announces that a new tool call has started.
	EventToolCallStart
	// EventToolCallDelta carries a chunk of a tool call's arguments.
	EventToolCallDelta
	// EventDone marks the end of the stream.
	EventDone
	// EventError reports that an error occurred.
	EventError
)
// StreamEvent represents a single event in the stream.
//
// Events are produced by a provider's parseEvent callback (or synthesized by
// StreamReader for end-of-stream and read errors) and returned from
// StreamReader.Next.
type StreamEvent struct {
	// Type selects which of the fields below are meaningful.
	Type EventType
	// Content is the text chunk of an EventContentDelta event.
	Content string
	// ToolCall is the (partial) tool call of an EventToolCallStart or
	// EventToolCallDelta event. For deltas StreamReader only reads its
	// Arguments, appending them to the most recently started tool call.
	ToolCall *ToolCall
	// ToolIndex is the index of the tool call being updated. For
	// EventToolCallStart, StreamReader.Next overwrites it with the call's
	// position in StreamReader.ToolCalls.
	ToolIndex int
	// Error is the failure carried by an EventError event.
	Error error
	// Usage is optional usage data attached to an EventDone event.
	Usage *Usage
}
// StreamReader reads events from an SSE stream.
//
// As Next returns content and tool-call events, it also accumulates them so
// the full response can be retrieved with Content, ToolCalls, or Collect.
// The accessor methods take mu, so a StreamReader may be shared between
// goroutines; note that Next holds mu while it blocks reading the stream.
type StreamReader struct {
	// mu serializes access to the mutable state below.
	mu sync.Mutex
	// reader buffers the underlying stream so it can be consumed line by line.
	reader *bufio.Reader
	// closer is the io.ReadCloser passed to NewStreamReader; Close closes it.
	closer io.Closer
	// content accumulates the text of EventContentDelta events.
	content strings.Builder
	// toolCalls accumulates tool calls: one entry per tool-call start event,
	// with tool-call delta arguments appended to the last entry.
	toolCalls []ToolCall
	// done is set once the stream has ended (EOF, a "[DONE]" payload, an
	// EventDone event) or Close has been called; Next then returns nil.
	done bool
	// err records a read error or the Error of an EventError event; once it
	// is non-nil, Next returns nil.
	err error
	// parseEvent is set by the provider to parse raw SSE data
	parseEvent func(data string) (*StreamEvent, error)
}
// NewStreamReader wraps r in a StreamReader that uses parseEvent to turn each
// SSE data payload into a StreamEvent.
//
// This is provider-facing: provider subpackages call it to wire up their own
// SSE parsing. The returned reader takes ownership of r; StreamReader.Close
// closes it.
func NewStreamReader(r io.ReadCloser, parseEvent func(data string) (*StreamEvent, error)) *StreamReader {
	sr := new(StreamReader)
	sr.reader = bufio.NewReader(r)
	sr.closer = r
	sr.parseEvent = parseEvent
	return sr
}
// Next returns the next event in the stream.
//
// It reads the stream line by line until a line produces an event:
//   - blank lines, ":" comment lines and SSE fields other than "data:"
//     (such as event:, id:, retry:) are skipped;
//   - a "data:" payload of "[DONE]" ends the stream with an EventDone event;
//   - any other non-empty payload is passed to the provider's parseEvent
//     callback. Payloads it rejects (error or nil event) are skipped on a
//     best-effort basis, as are tool-call events that carry no ToolCall.
//
// Parsed events are also folded into the reader's accumulated state (see
// Content, ToolCalls and Collect). For EventToolCallStart the event's
// ToolIndex is rewritten to the call's position in ToolCalls.
//
// When the underlying reader is exhausted Next returns an EventDone event;
// any other read error is recorded (see Err) and returned as an EventError
// event. Once the stream is done or has failed, Next returns nil.
//
// s.mu is held for the whole call, including the blocking read.
func (s *StreamReader) Next() *StreamEvent {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.done || s.err != nil {
		return nil
	}

	for {
		line, readErr := s.reader.ReadString('\n')
		if readErr != nil && readErr != io.EOF {
			s.err = readErr
			return &StreamEvent{Type: EventError, Error: readErr}
		}
		// ReadString returns the bytes read before io.EOF, so a final line
		// without a trailing newline must still be processed. If it yields an
		// event, the following call reads again and (as io.Reader
		// implementations conventionally keep reporting io.EOF) ends the stream.
		if event := s.handleSSELine(line); event != nil {
			return event
		}
		if readErr == io.EOF {
			s.done = true
			return &StreamEvent{Type: EventDone}
		}
	}
}

// handleSSELine interprets one raw SSE line and returns the event it
// produces, or nil when the line should be skipped. The caller must hold s.mu.
func (s *StreamReader) handleSSELine(line string) *StreamEvent {
	line = strings.TrimSpace(line)
	// Blank lines, ":" comments and non-data fields all fail this check.
	if !strings.HasPrefix(line, "data:") {
		return nil
	}
	// The SSE format makes the single space after "data:" optional.
	data := strings.TrimPrefix(strings.TrimPrefix(line, "data:"), " ")
	if data == "" {
		return nil
	}
	if data == "[DONE]" {
		s.done = true
		return &StreamEvent{Type: EventDone}
	}

	event, err := s.parseEvent(data)
	if err != nil || event == nil {
		// Best effort: skip payloads the provider cannot parse.
		return nil
	}
	if event.ToolCall == nil && (event.Type == EventToolCallStart || event.Type == EventToolCallDelta) {
		// Malformed tool-call event; dereferencing it would panic.
		return nil
	}
	s.applyStreamEvent(event)
	return event
}

// applyStreamEvent folds a parsed event into the accumulated response state.
// The caller must hold s.mu and must have rejected tool-call events whose
// ToolCall is nil.
func (s *StreamReader) applyStreamEvent(event *StreamEvent) {
	switch event.Type {
	case EventContentDelta:
		s.content.WriteString(event.Content)
	case EventToolCallStart:
		event.ToolIndex = len(s.toolCalls)
		s.toolCalls = append(s.toolCalls, *event.ToolCall)
	case EventToolCallDelta:
		// Arguments are appended to the most recently started tool call; the
		// provider's ToolIndex is not consulted, which assumes tool calls are
		// streamed one after another rather than interleaved.
		if last := len(s.toolCalls) - 1; last >= 0 {
			s.toolCalls[last].Arguments += event.ToolCall.Arguments
		}
	case EventDone:
		s.done = true
	case EventError:
		s.err = event.Error
	}
}
// Close closes the underlying reader and marks the stream as done, so
// subsequent calls to Next return nil. It returns the error from closing the
// underlying reader, if any.
//
// The underlying reader is closed before s.mu is acquired. Next holds s.mu for
// the duration of its blocking read, so taking the lock first would make a
// Close from another goroutine wait for that read to finish instead of being
// able to interrupt it. s.closer is assigned only by NewStreamReader, so
// reading it without the lock is race-free. Callers that close concurrently
// with Next need an underlying reader whose Close is safe to call while a
// Read is in progress.
func (s *StreamReader) Close() error {
	var err error
	if s.closer != nil {
		err = s.closer.Close()
	}

	s.mu.Lock()
	s.done = true
	s.mu.Unlock()

	return err
}
// Content returns the text of every content delta accumulated so far,
// concatenated in arrival order.
func (s *StreamReader) Content() string {
	s.mu.Lock()
	text := s.content.String()
	s.mu.Unlock()
	return text
}
// ToolCalls returns the tool calls accumulated so far.
//
// The result is a fresh, never-nil slice, so callers can keep it while the
// stream continues without racing on the reader's internal slice. The copy is
// shallow: each ToolCall value is copied as-is.
func (s *StreamReader) ToolCalls() []ToolCall {
	s.mu.Lock()
	defer s.mu.Unlock()
	return append(make([]ToolCall, 0, len(s.toolCalls)), s.toolCalls...)
}
// Done reports whether the stream has ended, either because its end was
// reached or because Close was called.
func (s *StreamReader) Done() bool {
	s.mu.Lock()
	finished := s.done
	s.mu.Unlock()
	return finished
}
// Err returns the error recorded while streaming, or nil if none has
// occurred.
func (s *StreamReader) Err() error {
	s.mu.Lock()
	recorded := s.err
	s.mu.Unlock()
	return recorded
}
// Collect drains the stream and returns the accumulated response.
//
// It calls Next until the stream ends. If a streaming error occurs, Collect
// returns a nil response and that error. This covers an EventError event
// carrying an error during this call, and an error recorded by an earlier
// call. Callers therefore never receive a partial response together with a
// non-nil error. EventError events without an Error are ignored, matching
// Next, which records nothing for them.
//
// The response's ToolCalls is a copy (never nil) of the accumulated tool
// calls. FinishReason is "tool_calls" when at least one tool call was
// received and "stop" otherwise. Usage attached to EventDone events is not
// consulted here.
func (s *StreamReader) Collect() (*ChatResponse, error) {
	for {
		event := s.Next()
		if event == nil || event.Type == EventDone {
			break
		}
		if event.Type == EventError && event.Error != nil {
			return nil, event.Error
		}
	}

	// Snapshot everything under one lock so the response is consistent.
	s.mu.Lock()
	err := s.err
	content := s.content.String()
	toolCalls := make([]ToolCall, len(s.toolCalls))
	copy(toolCalls, s.toolCalls)
	s.mu.Unlock()

	if err != nil {
		return nil, err
	}

	finishReason := "stop"
	if len(toolCalls) > 0 {
		finishReason = "tool_calls"
	}
	return &ChatResponse{
		Content:      content,
		ToolCalls:    toolCalls,
		FinishReason: finishReason,
	}, nil
}