// stream.go: Server-Sent Events (SSE) stream reading for chat responses.

package assistant

import (
	"bufio"
	"io"
	"strings"
	"sync"
)

// EventType classifies a StreamEvent produced while reading a stream.
type EventType int

// Event kinds reported by StreamReader.Next. The zero value of EventType is
// EventContentDelta.
const (
	// EventContentDelta carries a chunk of text content.
	EventContentDelta EventType = iota

	// EventToolCallStart announces that a new tool call has begun.
	EventToolCallStart

	// EventToolCallDelta carries a chunk of tool call arguments.
	EventToolCallDelta

	// EventDone marks the end of the stream.
	EventDone

	// EventError reports an error in the stream.
	EventError
)

// StreamEvent represents a single event in the stream.
//
// Type selects which of the remaining fields are meaningful; the trailing
// comment on each field names the event types that populate it. For
// EventToolCallStart events, StreamReader.Next overwrites ToolIndex with the
// call's position in the reader's accumulated tool calls.
type StreamEvent struct {
	Type      EventType
	Content   string    // For ContentDelta: the text chunk
	ToolCall  *ToolCall // For ToolCallStart/ToolCallDelta (partial; deltas carry an Arguments chunk)
	ToolIndex int       // Index of tool call being updated
	Error     error     // For Error events
	Usage     *Usage    // For Done events (optional)
}

// StreamReader reads events from an SSE stream.
//
// As events are consumed through Next, text content and tool calls are
// accumulated so the full result is available from Content, ToolCalls, or
// Collect. Every method acquires mu, so calls from multiple goroutines are
// serialized (including Close, which waits for an in-progress Next).
type StreamReader struct {
	mu        sync.Mutex      // guards every field below
	reader    *bufio.Reader   // buffered source that SSE lines are read from
	closer    io.Closer       // underlying stream, closed by Close
	content   strings.Builder // text accumulated from EventContentDelta events
	toolCalls []ToolCall      // tool calls accumulated from EventToolCallStart/Delta events
	done      bool            // set when the stream ends or Close is called
	err       error           // error that ended the stream; once non-nil, Next returns nil

	// parseEvent converts the payload of one SSE data line into a StreamEvent.
	// It is supplied by the provider through NewStreamReader; payloads it
	// rejects with an error are skipped by Next.
	parseEvent func(data string) (*StreamEvent, error)
}

// NewStreamReader wraps r in a StreamReader that passes each SSE data payload
// to parseEvent. It is provider-facing: provider subpackages call it to plug
// in their own event parsing. Calling Close on the returned reader closes r.
func NewStreamReader(r io.ReadCloser, parseEvent func(data string) (*StreamEvent, error)) *StreamReader {
	sr := new(StreamReader)
	sr.reader = bufio.NewReader(r)
	sr.closer = r
	sr.parseEvent = parseEvent
	return sr
}

// Next returns the next event in the stream, blocking until one is available.
//
// Content and tool-call events are folded into the reader's accumulated state
// (see Content and ToolCalls) before being returned. The stream finishes with
// a terminal event: EventDone (on a "[DONE]" payload, a provider EventDone, or
// end of input) or EventError (on a read failure, or a provider EventError
// carrying a non-nil Error). After a terminal event, and after Close, Next
// returns nil.
//
// Blank lines, SSE comments, non-data fields, empty payloads, payloads that
// parseEvent rejects or maps to nil, and tool-call events without a ToolCall
// are skipped.
func (s *StreamReader) Next() *StreamEvent {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.done || s.err != nil {
		return nil
	}

	for {
		line, err := s.reader.ReadString('\n')
		if err != nil && err != io.EOF {
			s.err = err
			return &StreamEvent{Type: EventError, Error: err}
		}

		// ReadString returns any trailing bytes together with io.EOF, so a
		// final line without a terminating newline must still be processed.
		if event := s.handleLine(line); event != nil {
			return event
		}

		if err == io.EOF {
			s.done = true
			return &StreamEvent{Type: EventDone}
		}
	}
}

// handleLine interprets one raw SSE line and returns the event to surface, or
// nil when the line carries nothing for the caller. Callers must hold s.mu.
func (s *StreamReader) handleLine(line string) *StreamEvent {
	line = strings.TrimSpace(line)

	// Blank lines separate events; lines starting with ':' are comments.
	if line == "" || strings.HasPrefix(line, ":") {
		return nil
	}

	// Only data fields carry payloads. Per the SSE spec the field may be
	// written "data:value" or "data: value"; at most one space is stripped.
	if !strings.HasPrefix(line, "data:") {
		return nil
	}
	data := strings.TrimPrefix(strings.TrimPrefix(line, "data:"), " ")
	if data == "" {
		return nil
	}

	if data == "[DONE]" {
		s.done = true
		return &StreamEvent{Type: EventDone}
	}

	event, err := s.parseEvent(data)
	if err != nil || event == nil {
		// Skip unparseable payloads and ones the provider chose to ignore.
		return nil
	}
	if !s.applyEvent(event) {
		return nil
	}
	return event
}

// applyEvent folds event into the accumulated state and reports whether the
// event should be surfaced to the caller. Callers must hold s.mu.
func (s *StreamReader) applyEvent(event *StreamEvent) bool {
	switch event.Type {
	case EventContentDelta:
		s.content.WriteString(event.Content)
	case EventToolCallStart:
		if event.ToolCall == nil {
			return false // malformed: nothing to record
		}
		event.ToolIndex = len(s.toolCalls)
		s.toolCalls = append(s.toolCalls, *event.ToolCall)
	case EventToolCallDelta:
		if event.ToolCall == nil {
			return false // malformed: no argument chunk
		}
		// The provider's index may not match positions in s.toolCalls, so the
		// chunk goes to the most recently started call and ToolIndex is
		// rewritten to that local position, mirroring EventToolCallStart.
		if idx := len(s.toolCalls) - 1; idx >= 0 {
			s.toolCalls[idx].Arguments += event.ToolCall.Arguments
			event.ToolIndex = idx
		}
	case EventDone:
		s.done = true
	case EventError:
		s.err = event.Error
	}
	return true
}

// Close marks the stream as finished and closes the underlying reader,
// returning its error (nil when there is no reader to close). After Close,
// Next returns nil. Close waits for any in-progress Next, since both hold the
// same lock.
func (s *StreamReader) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.done = true
	if s.closer == nil {
		return nil
	}
	return s.closer.Close()
}

// Content returns the text accumulated from every EventContentDelta event
// consumed so far.
func (s *StreamReader) Content() string {
	s.mu.Lock()
	text := s.content.String()
	s.mu.Unlock()
	return text
}

// ToolCalls returns a snapshot of the tool calls accumulated so far. The
// returned slice is always non-nil and never aliases the reader's internal
// storage, so later stream events do not change it.
func (s *StreamReader) ToolCalls() []ToolCall {
	s.mu.Lock()
	defer s.mu.Unlock()
	return append(make([]ToolCall, 0, len(s.toolCalls)), s.toolCalls...)
}

// Done reports whether the stream has finished. That happens when a "[DONE]"
// payload, a provider EventDone, or end of input is read, or when Close is
// called. It stays false after a read error; use Err to detect that case.
func (s *StreamReader) Done() bool {
	s.mu.Lock()
	finished := s.done
	s.mu.Unlock()
	return finished
}

// Err returns the error that ended the stream, or nil if none has occurred.
func (s *StreamReader) Err() error {
	s.mu.Lock()
	e := s.err
	s.mu.Unlock()
	return e
}

// Collect drains the stream and assembles the accumulated output into a
// ChatResponse.
//
// It consumes events through Next until the stream finishes. If an error ends
// the stream, Collect returns a nil response and that error. This applies both
// to an error reported while draining and to one already surfaced by an
// earlier Next call. Otherwise FinishReason is "tool_calls" when any tool call
// was accumulated and "stop" when none was.
func (s *StreamReader) Collect() (*ChatResponse, error) {
	for {
		event := s.Next()
		if event == nil || event.Type == EventDone {
			break
		}
		if event.Type == EventError {
			return nil, event.Error
		}
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	// Next returns nil without a terminal event when an error was already
	// reported by an earlier call; report it the same way as a fresh one.
	if s.err != nil {
		return nil, s.err
	}

	resp := &ChatResponse{
		Content:   s.content.String(),
		ToolCalls: make([]ToolCall, len(s.toolCalls)),
	}
	copy(resp.ToolCalls, s.toolCalls)

	resp.FinishReason = "stop"
	if len(resp.ToolCalls) > 0 {
		resp.FinishReason = "tool_calls"
	}
	return resp, nil
}