diff --git a/cmd/lk/agent.go b/cmd/lk/agent.go index 1cdf4eb5..471e95e2 100644 --- a/cmd/lk/agent.go +++ b/cmd/lk/agent.go @@ -346,6 +346,7 @@ var ( ArgsUsage: "[working-dir]", }, privateLinkCommands, + simulateCommand, }, }, } diff --git a/cmd/lk/agent_private_link.go b/cmd/lk/agent_private_link.go index 34043fd3..36b783d1 100644 --- a/cmd/lk/agent_private_link.go +++ b/cmd/lk/agent_private_link.go @@ -3,7 +3,6 @@ package main import ( "context" "fmt" - "strconv" "github.com/livekit/livekit-cli/v2/pkg/util" lkproto "github.com/livekit/protocol/livekit" @@ -87,11 +86,9 @@ var privateLinkCommands = &cli.Command{ }, } -func buildCreatePrivateLinkRequest(name, region string, port uint32, awsEndpoint string) *lkproto.CreatePrivateLinkRequest { +func buildCreatePrivateLinkRequest(name, awsEndpoint string) *lkproto.CreatePrivateLinkRequest { return &lkproto.CreatePrivateLinkRequest{ - Name: name, - Region: region, - Port: port, + Name: name, Config: &lkproto.CreatePrivateLinkRequest_Aws{ Aws: &lkproto.CreatePrivateLinkRequest_AWSCreateConfig{ Endpoint: awsEndpoint, @@ -104,14 +101,14 @@ func privateLinkServiceDNS(name, projectID string) string { return fmt.Sprintf("%s-%s.plg.svc", name, projectID) } -func buildPrivateLinkListRows(links []*lkproto.PrivateLink, healthByID map[string]*lkproto.PrivateLinkStatus, healthErrByID map[string]error) [][]string { +func buildPrivateLinkListRows(links []*lkproto.PrivateLink, healthByID map[string]*lkproto.PrivateLinkHealthStatus, healthErrByID map[string]error) [][]string { var rows [][]string for _, link := range links { if link == nil { continue } - status := lkproto.PrivateLinkStatus_PRIVATE_LINK_STATUS_UNKNOWN.String() + status := lkproto.PrivateLinkHealthStatus_PRIVATE_LINK_ATTACHMENT_HEALTH_STATUS_UNKNOWN.String() updatedAt := "-" if err, ok := healthErrByID[link.PrivateLinkId]; ok && err != nil { @@ -127,8 +124,6 @@ func buildPrivateLinkListRows(links []*lkproto.PrivateLink, healthByID map[strin rows = append(rows, []string{ 
link.PrivateLinkId, link.Name, - link.Region, - strconv.FormatUint(uint64(link.Port), 10), status, updatedAt, }) @@ -144,7 +139,7 @@ func formatPrivateLinkClientError(action string, err error) error { } func createPrivateLink(ctx context.Context, cmd *cli.Command) error { - req := buildCreatePrivateLinkRequest(cmd.String("name"), cmd.String("region"), uint32(cmd.Uint("port")), cmd.String("endpoint")) + req := buildCreatePrivateLinkRequest(cmd.String("name"), cmd.String("endpoint")) resp, err := agentsClient.CreatePrivateLink(ctx, req) if err != nil { return formatPrivateLinkClientError("create", err) @@ -173,13 +168,13 @@ func listPrivateLinks(ctx context.Context, cmd *cli.Command) error { return formatPrivateLinkClientError("list", err) } - healthByID := make(map[string]*lkproto.PrivateLinkStatus, len(resp.Items)) + healthByID := make(map[string]*lkproto.PrivateLinkHealthStatus, len(resp.Items)) healthErrByID := make(map[string]error) for _, link := range resp.Items { if link == nil || link.PrivateLinkId == "" { continue } - health, healthErr := agentsClient.GetPrivateLinkStatus(ctx, &lkproto.GetPrivateLinkStatusRequest{ + health, healthErr := agentsClient.GetPrivateLinkHealthStatus(ctx, &lkproto.GetPrivateLinkHealthStatusRequest{ PrivateLinkId: link.PrivateLinkId, }) if healthErr != nil { @@ -193,9 +188,9 @@ func listPrivateLinks(ctx context.Context, cmd *cli.Command) error { if cmd.Bool("json") { type privateLinkWithHealth struct { - PrivateLink *lkproto.PrivateLink `json:"private_link"` - Status *lkproto.PrivateLinkStatus `json:"health"` - HealthError string `json:"health_error,omitempty"` + PrivateLink *lkproto.PrivateLink `json:"private_link"` + Status *lkproto.PrivateLinkHealthStatus `json:"health"` + HealthError string `json:"health_error,omitempty"` } items := make([]privateLinkWithHealth, 0, len(resp.Items)) for _, link := range resp.Items { @@ -221,7 +216,7 @@ func listPrivateLinks(ctx context.Context, cmd *cli.Command) error { } rows := 
buildPrivateLinkListRows(resp.Items, healthByID, healthErrByID) - table := util.CreateTable().Headers("ID", "Name", "Region", "Port", "Health", "Updated At").Rows(rows...) + table := util.CreateTable().Headers("ID", "Name", "Health", "Updated At").Rows(rows...) fmt.Println(table) return nil } @@ -245,7 +240,7 @@ func deletePrivateLink(ctx context.Context, cmd *cli.Command) error { func getPrivateLinkHealthStatus(ctx context.Context, cmd *cli.Command) error { privateLinkID := cmd.String("id") - resp, err := agentsClient.GetPrivateLinkStatus(ctx, &lkproto.GetPrivateLinkStatusRequest{ + resp, err := agentsClient.GetPrivateLinkHealthStatus(ctx, &lkproto.GetPrivateLinkHealthStatusRequest{ PrivateLinkId: privateLinkID, }) if err != nil { diff --git a/cmd/lk/agent_private_link_test.go b/cmd/lk/agent_private_link_test.go index 9b9fb9dd..16ac78a0 100644 --- a/cmd/lk/agent_private_link_test.go +++ b/cmd/lk/agent_private_link_test.go @@ -45,12 +45,10 @@ func TestAgentPrivateLinkCommandTree(t *testing.T) { } func TestBuildCreatePrivateLinkRequest_HappyPath(t *testing.T) { - req := buildCreatePrivateLinkRequest("orders-db", "us-east-1", 6379, "com.amazonaws.vpce.us-east-1.vpce-svc-abc123") + req := buildCreatePrivateLinkRequest("orders-db", "com.amazonaws.vpce.us-east-1.vpce-svc-abc123") require.NotNil(t, req) assert.Equal(t, "orders-db", req.Name) - assert.Equal(t, "us-east-1", req.Region) - assert.Equal(t, uint32(6379), req.Port) aws := req.GetAws() require.NotNil(t, aws) @@ -62,7 +60,7 @@ func TestPrivateLinkServiceDNS(t *testing.T) { } func TestBuildPrivateLinkListRows_EmptyList(t *testing.T) { - rows := buildPrivateLinkListRows([]*lkproto.PrivateLink{}, map[string]*lkproto.PrivateLinkStatus{}, map[string]error{}) + rows := buildPrivateLinkListRows([]*lkproto.PrivateLink{}, map[string]*lkproto.PrivateLinkHealthStatus{}, map[string]error{}) assert.Empty(t, rows) } @@ -71,15 +69,13 @@ func TestBuildPrivateLinkListRows_OnePrivateLink(t *testing.T) { { PrivateLinkId: "pl-1", 
Name: "orders-db", - Region: "us-east-1", - Port: 6379, }, } now := time.Now().UTC() - healthByID := map[string]*lkproto.PrivateLinkStatus{ + healthByID := map[string]*lkproto.PrivateLinkHealthStatus{ "pl-1": { - Status: lkproto.PrivateLinkStatus_PRIVATE_LINK_STATUS_AVAILABLE, + Status: lkproto.PrivateLinkHealthStatus_PRIVATE_LINK_ATTACHMENT_HEALTH_STATUS_HEALTHY, UpdatedAt: timestamppb.New(now), }, } @@ -88,41 +84,33 @@ func TestBuildPrivateLinkListRows_OnePrivateLink(t *testing.T) { require.Len(t, rows, 1) assert.Equal(t, "pl-1", rows[0][0]) assert.Equal(t, "orders-db", rows[0][1]) - assert.Equal(t, "us-east-1", rows[0][2]) - assert.Equal(t, "6379", rows[0][3]) - assert.Equal(t, lkproto.PrivateLinkStatus_PRIVATE_LINK_STATUS_AVAILABLE.String(), rows[0][4]) + assert.Equal(t, lkproto.PrivateLinkHealthStatus_PRIVATE_LINK_ATTACHMENT_HEALTH_STATUS_HEALTHY.String(), rows[0][2]) } -func TestBuildPrivateLinkListRows_TwoPrivateLinksDifferentRegions(t *testing.T) { +func TestBuildPrivateLinkListRows_TwoPrivateLinks(t *testing.T) { links := []*lkproto.PrivateLink{ { PrivateLinkId: "pl-1", Name: "orders-db", - Region: "us-east-1", - Port: 6379, }, { PrivateLinkId: "pl-2", Name: "cache", - Region: "eu-west-1", - Port: 6380, }, } - healthByID := map[string]*lkproto.PrivateLinkStatus{ + healthByID := map[string]*lkproto.PrivateLinkHealthStatus{ "pl-1": { - Status: lkproto.PrivateLinkStatus_PRIVATE_LINK_STATUS_AVAILABLE, + Status: lkproto.PrivateLinkHealthStatus_PRIVATE_LINK_ATTACHMENT_HEALTH_STATUS_HEALTHY, }, "pl-2": { - Status: lkproto.PrivateLinkStatus_PRIVATE_LINK_STATUS_AVAILABLE, + Status: lkproto.PrivateLinkHealthStatus_PRIVATE_LINK_ATTACHMENT_HEALTH_STATUS_HEALTHY, }, } rows := buildPrivateLinkListRows(links, healthByID, map[string]error{}) require.Len(t, rows, 2) - assert.Equal(t, "us-east-1", rows[0][2]) - assert.Equal(t, "eu-west-1", rows[1][2]) - assert.Equal(t, lkproto.PrivateLinkStatus_PRIVATE_LINK_STATUS_AVAILABLE.String(), rows[0][4]) - assert.Equal(t, 
lkproto.PrivateLinkStatus_PRIVATE_LINK_STATUS_AVAILABLE.String(), rows[1][4]) + assert.Equal(t, lkproto.PrivateLinkHealthStatus_PRIVATE_LINK_ATTACHMENT_HEALTH_STATUS_HEALTHY.String(), rows[0][2]) + assert.Equal(t, lkproto.PrivateLinkHealthStatus_PRIVATE_LINK_ATTACHMENT_HEALTH_STATUS_HEALTHY.String(), rows[1][2]) } diff --git a/cmd/lk/console.go b/cmd/lk/console.go index 2dab4130..9c22657b 100644 --- a/cmd/lk/console.go +++ b/cmd/lk/console.go @@ -19,48 +19,59 @@ package main import ( "context" "fmt" + "io" + "log" + "net" "os" "strings" + "time" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" "github.com/urfave/cli/v3" + "github.com/livekit/livekit-cli/v2/pkg/agentfs" "github.com/livekit/livekit-cli/v2/pkg/console" "github.com/livekit/livekit-cli/v2/pkg/portaudio" ) -var ConsoleCommands = []*cli.Command{ - { - Name: "console", - Usage: "Voice chat with an agent via mic/speakers", - Category: "Core", - Flags: []cli.Flag{ - &cli.IntFlag{ - Name: "port", - Aliases: []string{"p"}, - Usage: "TCP port for agent communication", - Value: 0, - }, - &cli.StringFlag{ - Name: "input-device", - Usage: "Input device index or name substring", - }, - &cli.StringFlag{ - Name: "output-device", - Usage: "Output device index or name substring", - }, - &cli.BoolFlag{ - Name: "list-devices", - Usage: "List available audio devices and exit", - }, - &cli.BoolFlag{ - Name: "no-aec", - Usage: "Disable acoustic echo cancellation", - }, +func init() { + AgentCommands[0].Commands = append(AgentCommands[0].Commands, consoleCommand) +} + +var consoleCommand = &cli.Command{ + Name: "console", + Usage: "Voice chat with an agent via mic/speakers", + Category: "Core", + Flags: []cli.Flag{ + &cli.IntFlag{ + Name: "port", + Aliases: []string{"p"}, + Usage: "TCP port for agent communication", + Value: 0, + }, + &cli.StringFlag{ + Name: "input-device", + Usage: "Input device index or name substring", + }, + &cli.StringFlag{ + Name: "output-device", + Usage: "Output device 
index or name substring", + }, + &cli.BoolFlag{ + Name: "list-devices", + Usage: "List available audio devices and exit", + }, + &cli.BoolFlag{ + Name: "no-aec", + Usage: "Disable acoustic echo cancellation", + }, + &cli.StringFlag{ + Name: "entrypoint", + Usage: "Agent entrypoint `FILE` (default: auto-detect)", }, - Action: runConsole, }, + Action: runConsole, } func runConsole(ctx context.Context, cmd *cli.Command) error { @@ -103,19 +114,75 @@ func runConsole(ctx context.Context, cmd *cli.Command) error { defer server.Close() actualAddr := server.Addr().String() - fmt.Fprintf(os.Stderr, "Listening on %s\n", actualAddr) fmt.Fprintf(os.Stderr, "Input: %s\n", inputDev.Name) fmt.Fprintf(os.Stderr, "Output: %s\n", outputDev.Name) - fmt.Fprintf(os.Stderr, "Waiting for agent connection...\n") - conn, err := server.Accept() + // Detect project type, walking up parent directories if needed. + projectDir, projectType, err := agentfs.DetectProjectRoot(".") + if err != nil { + return err + } + if !projectType.IsPython() { + return fmt.Errorf("console currently only supports Python agents (detected: %s)", projectType) + } + + // Resolve entrypoint relative to project root + entrypoint, err := findEntrypoint(projectDir, cmd.String("entrypoint"), projectType) + if err != nil { + return err + } + + fmt.Fprintf(os.Stderr, "Starting agent (%s in %s)...\n", entrypoint, projectDir) + agentProc, err := startAgent(AgentStartConfig{ + Dir: projectDir, + Entrypoint: entrypoint, + ProjectType: projectType, + CLIArgs: []string{"console", "--connect-addr", actualAddr}, + }) if err != nil { - return fmt.Errorf("agent connection: %w", err) + return fmt.Errorf("failed to start agent: %w", err) } - defer conn.Close() + defer agentProc.Kill() + + // Stream agent logs to the TUI + agentProc.LogStream = make(chan string, 128) - fmt.Fprintf(os.Stderr, "Agent connected from %s\n", conn.RemoteAddr()) + // Wait for TCP connection, agent crash, timeout, or cancellation + type acceptResult struct { + 
conn net.Conn + err error + } + acceptCh := make(chan acceptResult, 1) + go func() { + conn, err := server.Accept() + acceptCh <- acceptResult{conn, err} + }() + var conn net.Conn + select { + case res := <-acceptCh: + if res.err != nil { + return fmt.Errorf("agent connection: %w", res.err) + } + conn = res.conn + case err := <-agentProc.Done(): + logs := agentProc.RecentLogs(20) + for _, l := range logs { + fmt.Fprintln(os.Stderr, l) + } + if err != nil { + return fmt.Errorf("agent exited before connecting: %w", err) + } + return fmt.Errorf("agent exited before connecting") + case <-time.After(60 * time.Second): + logs := agentProc.RecentLogs(20) + for _, l := range logs { + fmt.Fprintln(os.Stderr, l) + } + return fmt.Errorf("timed out waiting for agent to connect") + case <-ctx.Done(): + return ctx.Err() + } pipeline, err := console.NewPipeline(console.PipelineConfig{ InputDevice: inputDev, OutputDevice: outputDev, @@ -133,8 +200,11 @@ func runConsole(ctx context.Context, cmd *cli.Command) error { pipeline.Start(pipelineCtx) }() - model := newConsoleModel(pipeline, actualAddr, inputDev, outputDev) - p := tea.NewProgram(model, tea.WithAltScreen()) + // Redirect Go's default logger to discard so it doesn't corrupt the TUI + log.SetOutput(io.Discard) + + model := newConsoleModel(pipeline, agentProc, inputDev.Name, outputDev.Name) + p := tea.NewProgram(model) if _, err := p.Run(); err != nil { return err @@ -184,3 +254,4 @@ func listDevices() error { return nil } + diff --git a/cmd/lk/console_stub.go b/cmd/lk/console_stub.go index 6cc487ba..917095cb 100644 --- a/cmd/lk/console_stub.go +++ b/cmd/lk/console_stub.go @@ -2,8 +2,45 @@ package main -import "github.com/urfave/cli/v3" +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" -// ConsoleCommands is nil when built without the console tag. -// This ensures the default build (CGO_ENABLED=0) is unaffected. 
// isHomebrewInstall reports whether the running executable, once symlinks are
// resolved, sits under a path containing "/Cellar/".
//
// Any failure to locate the executable or to resolve its symlinks is treated
// as "not a Homebrew install" rather than surfaced as an error.
func isHomebrewInstall() bool {
	exePath, err := os.Executable()
	if err == nil {
		// The path returned by os.Executable may be a symlink; only the
		// resolved location is inspected.
		exePath, err = filepath.EvalSymlinks(exePath)
	}
	return err == nil && strings.Contains(exePath, "/Cellar/")
}
= lipgloss.NewStyle().Faint(true) - consoleBoldStyle = lipgloss.NewStyle().Bold(true) - consoleCyanStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("6")) - consoleYellowStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("3")) + lkCyan = lipgloss.Color("#1fd5f9") + lkPurple = lipgloss.Color("#8f83ff") + lkGreen = lipgloss.Color("#6BCB77") + lkRed = lipgloss.Color("#EF4444") + + labelStyle = lipgloss.NewStyle().Foreground(lkPurple) + cyanBoldStyle = lipgloss.NewStyle().Foreground(lkCyan).Bold(true) + greenBoldStyle = lipgloss.NewStyle().Foreground(lkGreen).Bold(true) + redBoldStyle = lipgloss.NewStyle().Foreground(lkRed).Bold(true) ) -// Unicode block characters for frequency visualizer -var blocks = []rune{' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█'} +// Unicode block characters for frequency visualizer (matching Python console) +var blocks = []string{"▁", "▂", "▃", "▄", "▅", "▆", "▇", "█"} type consoleTickMsg struct{} +type sessionEventMsg struct{ event *agent.SessionEvent } +type agentLogMsg struct{ line string } type consoleModel struct { pipeline *console.AudioPipeline - addr string - inputDev *portaudio.DeviceInfo - outputDev *portaudio.DeviceInfo + agentProc *AgentProcess + inputDev string + outputDev string + + width int + + // Partial user transcription (not yet final) + partialTranscript string + + // Text mode + textMode bool + textInput textinput.Model + + // Shortcut help toggle (? 
key) + showShortcuts bool + + // Last turn metrics text (cleared on next thinking state) + metricsText string - width int - height int + // Request counter for unique IDs + reqCounter int } -func newConsoleModel(pipeline *console.AudioPipeline, addr string, inputDev, outputDev *portaudio.DeviceInfo) consoleModel { +func newConsoleModel(pipeline *console.AudioPipeline, agentProc *AgentProcess, inputDev, outputDev string) consoleModel { + ti := textinput.New() + ti.Placeholder = "Type to talk to your agent" + ti.CharLimit = 1000 + ti.Width = 60 + ti.Prompt = "❯ " + ti.PromptStyle = boldStyle + return consoleModel{ pipeline: pipeline, - addr: addr, + agentProc: agentProc, inputDev: inputDev, outputDev: outputDev, + textInput: ti, } } func (m consoleModel) Init() tea.Cmd { - return tea.Batch( + cmds := []tea.Cmd{ consoleTickCmd(), - ) + pollEventsCmd(m.pipeline), + } + if m.agentProc != nil && m.agentProc.LogStream != nil { + cmds = append(cmds, pollLogsCmd(m.agentProc.LogStream)) + } + return tea.Batch(cmds...) 
} func consoleTickCmd() tea.Cmd { - return tea.Tick(50*time.Millisecond, func(t time.Time) tea.Msg { + return tea.Tick(80*time.Millisecond, func(t time.Time) tea.Msg { return consoleTickMsg{} }) } +func pollEventsCmd(pipeline *console.AudioPipeline) tea.Cmd { + return func() tea.Msg { + ev, ok := <-pipeline.Events + if !ok { + return nil + } + return sessionEventMsg{event: ev} + } +} + +func pollLogsCmd(ch chan string) tea.Cmd { + return func() tea.Msg { + line, ok := <-ch + if !ok { + return nil + } + return agentLogMsg{line: line} + } +} + func (m consoleModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { case tea.KeyMsg: + if m.textMode { + return m.updateTextMode(msg) + } switch msg.String() { case "q", "ctrl+c": return m, tea.Quit case "m": m.pipeline.SetMuted(!m.pipeline.Muted()) + case "ctrl+t": + m.textMode = true + m.showShortcuts = false + m.textInput.Focus() + return m, textinput.Blink + case "?": + m.showShortcuts = !m.showShortcuts + case "esc": + m.showShortcuts = false } case tea.WindowSizeMsg: m.width = msg.Width - m.height = msg.Height case consoleTickMsg: return m, consoleTickCmd() + + case sessionEventMsg: + cmds := m.handleSessionEvent(msg.event) + cmds = append(cmds, pollEventsCmd(m.pipeline)) + return m, tea.Batch(cmds...) 
+ + case agentLogMsg: + cmd := tea.Println(dimStyle.Render(msg.line)) + var nextCmd tea.Cmd + if m.agentProc != nil && m.agentProc.LogStream != nil { + nextCmd = pollLogsCmd(m.agentProc.LogStream) + } + return m, tea.Batch(cmd, nextCmd) } return m, nil } +func (m *consoleModel) updateTextMode(msg tea.KeyMsg) (tea.Model, tea.Cmd) { + switch msg.String() { + case "ctrl+c": + return m, tea.Quit + case "ctrl+t": + m.textMode = false + m.showShortcuts = false + m.textInput.Blur() + return m, nil + case "esc": + if m.showShortcuts { + m.showShortcuts = false + return m, nil + } + m.textMode = false + m.textInput.Blur() + return m, nil + case "?": + if m.textInput.Value() == "" { + m.showShortcuts = !m.showShortcuts + return m, nil + } + case "enter": + text := strings.TrimSpace(m.textInput.Value()) + if text != "" { + m.reqCounter++ + reqID := fmt.Sprintf("console-%d", m.reqCounter) + m.textInput.SetValue("") + + // Print user message matching the old console format: + // ● You + // text here + printCmd := tea.Println( + "\n " + lipgloss.NewStyle().Foreground(lkCyan).Render("● ") + + cyanBoldStyle.Render("You") + + "\n " + text, + ) + + req := &agent.SessionRequest{ + RequestId: reqID, + Request: &agent.SessionRequest_SendMessage{ + SendMessage: &agent.SendMessageRequest{Text: text}, + }, + } + go m.pipeline.SendRequest(req) + return m, printCmd + } + return m, nil + } + + var cmd tea.Cmd + m.textInput, cmd = m.textInput.Update(msg) + return m, cmd +} + +func (m *consoleModel) handleSessionEvent(ev *agent.SessionEvent) []tea.Cmd { + if ev == nil { + return nil + } + var cmds []tea.Cmd + + switch e := ev.Event.(type) { + case *agent.SessionEvent_AgentStateChanged: + if e.AgentStateChanged.NewState == agent.AgentState_AGENT_STATE_THINKING { + m.metricsText = "" + } + + case *agent.SessionEvent_UserInputTranscribed: + if e.UserInputTranscribed.IsFinal { + m.partialTranscript = "" + if text := e.UserInputTranscribed.Transcript; text != "" { + cmds = append(cmds, tea.Println( 
+ "\n "+lipgloss.NewStyle().Foreground(lkCyan).Render("● ")+ + cyanBoldStyle.Render("You")+ + "\n "+text, + )) + } + } else { + m.partialTranscript = e.UserInputTranscribed.Transcript + } + + case *agent.SessionEvent_ConversationItemAdded: + if item := e.ConversationItemAdded.Item; item != nil { + // Extract metrics from ChatMessage (matching Python console pattern) + if msg := item.GetMessage(); msg != nil { + if text := formatMetrics(msg.Metrics); text != "" { + m.metricsText = text + } + } + lines := formatChatItem(item) + for _, line := range lines { + cmds = append(cmds, tea.Println(line)) + } + } + + case *agent.SessionEvent_FunctionToolsExecuted: + for _, fc := range e.FunctionToolsExecuted.FunctionCalls { + cmds = append(cmds, tea.Println( + " "+lipgloss.NewStyle().Foreground(lkCyan).Render("➜ ")+ + cyanBoldStyle.Render(fc.Name), + )) + } + for _, fco := range e.FunctionToolsExecuted.FunctionCallOutputs { + if fco.IsError { + cmds = append(cmds, tea.Println( + " "+redBoldStyle.Render("✗ ")+redStyle.Render(truncateOutput(fco.Output)), + )) + } else { + cmds = append(cmds, tea.Println( + " "+greenStyle.Render("✓ ")+dimStyle.Render(summarizeOutput(fco.Output)), + )) + } + } + + case *agent.SessionEvent_Error: + cmds = append(cmds, tea.Println( + " "+redBoldStyle.Render("✗ ")+redStyle.Render(e.Error.Message), + )) + } + + return cmds +} + +// formatChatItem returns lines to print for a conversation item, +// matching the old Python console format. +func formatChatItem(item *agent.ChatContext_ChatItem) []string { + switch i := item.Item.(type) { + case *agent.ChatContext_ChatItem_Message: + msg := i.Message + // User messages are printed from UserInputTranscribed (final) to avoid + // ordering issues with partial transcripts. 
+ if msg.Role == agent.ChatRole_USER { + return nil + } + var textParts []string + for _, c := range msg.Content { + if t := c.GetText(); t != "" { + textParts = append(textParts, t) + } + } + text := strings.Join(textParts, "") + if text == "" { + return nil + } + + var lines []string + lines = append(lines, + "\n "+lipgloss.NewStyle().Foreground(lkGreen).Render("● ")+ + greenBoldStyle.Render("Agent"), + ) + for _, tl := range strings.Split(text, "\n") { + lines = append(lines, " "+tl) + } + return lines + + case *agent.ChatContext_ChatItem_FunctionCall: + return []string{ + " " + lipgloss.NewStyle().Foreground(lkCyan).Render("➜ ") + + cyanBoldStyle.Render(i.FunctionCall.Name), + } + + case *agent.ChatContext_ChatItem_FunctionCallOutput: + if i.FunctionCallOutput.IsError { + return []string{ + " " + redBoldStyle.Render("✗ ") + redStyle.Render(truncateOutput(i.FunctionCallOutput.Output)), + } + } + return []string{ + " " + greenStyle.Render("✓ ") + dimStyle.Render(summarizeOutput(i.FunctionCallOutput.Output)), + } + } + return nil +} + +// ────────────────────────────────────────────────────────────────── +// View — compact status area at the bottom (not fullscreen). +// Logs and conversation scroll up via tea.Println. +// Layout matches the old Python console (FrequencyVisualizer + prompt). 
+// ────────────────────────────────────────────────────────────────── + func (m consoleModel) View() string { var b strings.Builder - b.WriteString(consoleTitleStyle.Render(" lk console ")) - b.WriteString("\n\n") + if m.textMode { + // ── Text input (matching old Python prompt layout) ── + w := m.width + if w <= 0 { + w = 80 + } + sep := dimStyle.Render(strings.Repeat("─", min(w, 80))) + b.WriteString(sep) + b.WriteString("\n") + b.WriteString(m.textInput.View()) + b.WriteString("\n") + b.WriteString(sep) - b.WriteString(consoleBoldStyle.Render("Status: ")) - b.WriteString(consoleGreenStyle.Render("● Connected")) - b.WriteString(" ") - b.WriteString(consoleDimStyle.Render(m.addr)) - b.WriteString("\n") - - b.WriteString(consoleBoldStyle.Render("Input: ")) - b.WriteString(m.inputDev.Name) - b.WriteString("\n") - b.WriteString(consoleBoldStyle.Render("Output: ")) - b.WriteString(m.outputDev.Name) - b.WriteString("\n\n") - - bands := m.pipeline.FFTBands() - b.WriteString(consoleBoldStyle.Render("Audio ")) - for _, band := range bands { - idx := int(band * float64(len(blocks)-1)) - if idx >= len(blocks) { - idx = len(blocks) - 1 - } - if idx < 0 { - idx = 0 - } - if band > 0.5 { - b.WriteString(consoleCyanStyle.Render(string(blocks[idx]))) - } else if band > 0.2 { - b.WriteString(consoleGreenStyle.Render(string(blocks[idx]))) + if m.showShortcuts { + b.WriteString("\n") + m.writeShortcutsInline(&b, []shortcut{ + {"Ctrl+T", "audio mode"}, + {"Ctrl+C", "exit"}, + }) } else { - b.WriteString(consoleDimStyle.Render(string(blocks[idx]))) + b.WriteString("\n") + b.WriteString(dimStyle.Render(" ? 
for shortcuts")) + } + } else { + // ── Audio visualizer (matching old Python FrequencyVisualizer) ── + b.WriteString(" ") + b.WriteString(labelStyle.Render(m.inputDev)) + b.WriteString(" ") + bands := m.pipeline.FFTBands() + for _, band := range bands { + idx := int(band * float64(len(blocks)-1)) + if idx >= len(blocks) { + idx = len(blocks) - 1 + } + if idx < 0 { + idx = 0 + } + b.WriteString(" ") + b.WriteString(blocks[idx]) + } + + if m.pipeline.Muted() { + b.WriteString(" ") + b.WriteString(redBoldStyle.Render("MUTED")) + } + + // Partial transcription on same line (dim) + if m.partialTranscript != "" { + b.WriteString(" ") + b.WriteString(dimStyle.Render("● " + m.partialTranscript + "...")) + } + + // ERLE > 6dB means the AEC is actively cancelling echo — show as a + // reassuring status indicator, not a warning. + if m.pipeline.IsPlaying() { + if stats := m.pipeline.AECStats(); stats != nil && stats.HasERLE && stats.EchoReturnLossEnhancement > 2 { + b.WriteString(" ") + b.WriteString(dimStyle.Render("echo cancelling")) + } + } + + // Metrics on same line (right side) + if m.metricsText != "" { + b.WriteString(" ") + b.WriteString(m.metricsText) + } + + if m.showShortcuts { + b.WriteString("\n") + m.writeShortcutsInline(&b, []shortcut{ + {"m", "mute/unmute"}, + {"Ctrl+T", "text mode"}, + {"q", "quit"}, + }) + } else { + b.WriteString("\n") + b.WriteString(dimStyle.Render(" ? 
for shortcuts")) } } - b.WriteString("\n\n") - level := m.pipeline.Level() - b.WriteString(consoleBoldStyle.Render("Mic: ")) + return b.String() +} - if m.pipeline.Muted() { - b.WriteString(consoleRedStyle.Render("MUTED")) - } else { - // Level bar: -60dB to 0dB - normalized := (level + 60) / 60 - if normalized < 0 { - normalized = 0 - } - if normalized > 1 { - normalized = 1 - } - barWidth := 30 - filled := int(normalized * float64(barWidth)) - bar := strings.Repeat("█", filled) + strings.Repeat("░", barWidth-filled) - if normalized > 0.8 { - b.WriteString(consoleRedStyle.Render(bar)) - } else if normalized > 0.5 { - b.WriteString(consoleYellowStyle.Render(bar)) +type shortcut struct { + key string + desc string +} + +func (m consoleModel) writeShortcutsInline(b *strings.Builder, shortcuts []shortcut) { + dimBoldStyle := lipgloss.NewStyle().Faint(true).Bold(true) + b.WriteString(" ") + for i, s := range shortcuts { + if i > 0 { + b.WriteString(dimStyle.Render(" · ")) + } + b.WriteString(dimBoldStyle.Render(s.key)) + b.WriteString(" ") + b.WriteString(dimStyle.Render(s.desc)) + } +} + +// formatMetrics formats a MetricsReport matching the Python console display. 
// formatMs renders a duration given in seconds as milliseconds: whole
// milliseconds at 100ms and above, one decimal place below that.
func formatMs(seconds float64) string {
	ms := seconds * 1000
	if ms >= 100 {
		return fmt.Sprintf("%.0fms", ms)
	}
	return fmt.Sprintf("%.1fms", ms)
}

// summarizeOutput condenses a tool output into a "key=value, key=value"
// summary of at most three top-level fields of the JSON object that starts at
// the first '{' in output.
//
// Fields are taken in document order, so the result is the same on every call
// (ranging over the decoded map would pick fields in random order). Entries
// whose value is null and the "type" key are skipped. ", ..." is appended only
// when at least one displayable field was left out. If output holds no '{',
// the remainder is not a single valid JSON object, or nothing is displayable,
// the truncated raw output is returned instead.
func summarizeOutput(output string) string {
	jsonStart := strings.Index(output, "{")
	if jsonStart < 0 {
		return truncateOutput(output)
	}
	payload := output[jsonStart:]

	// Unmarshal validates the whole remainder, including trailing bytes, and
	// provides the decoded values (last duplicate key wins).
	var data map[string]any
	if err := json.Unmarshal([]byte(payload), &data); err != nil {
		return truncateOutput(output)
	}

	const maxFields = 3
	parts := make([]string, 0, maxFields)
	omitted := false
	for _, k := range objectKeysInOrder(payload) {
		v := data[k]
		if v == nil || k == "type" {
			continue
		}
		if len(parts) == maxFields {
			omitted = true
			break
		}
		parts = append(parts, fmt.Sprintf("%s=%v", k, v))
	}
	if len(parts) == 0 {
		return truncateOutput(output)
	}

	result := strings.Join(parts, ", ")
	if omitted {
		result += ", ..."
	}
	return result
}

// objectKeysInOrder returns the top-level keys of the JSON object in payload
// in the order they appear, listing each key once. It stops at the first
// token it cannot read and returns the keys collected so far.
func objectKeysInOrder(payload string) []string {
	dec := json.NewDecoder(strings.NewReader(payload))
	if _, err := dec.Token(); err != nil { // consume the opening '{'
		return nil
	}

	var keys []string
	seen := make(map[string]struct{})
	for dec.More() {
		tok, err := dec.Token()
		if err != nil {
			return keys
		}
		key, ok := tok.(string)
		if !ok {
			return keys
		}
		// Skip over the value without interpreting it.
		var skipped json.RawMessage
		if err := dec.Decode(&skipped); err != nil {
			return keys
		}
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}
		keys = append(keys, key)
	}
	return keys
}

// truncateOutput returns output unchanged when it is at most 200 bytes long.
// Longer output is cut to at most 197 bytes and suffixed with "...". The cut
// is moved back to the nearest rune boundary so a multi-byte UTF-8 character
// is never split.
func truncateOutput(output string) string {
	const (
		maxLen = 200
		keep   = maxLen - len("...")
	)
	if len(output) <= maxLen {
		return output
	}

	// Ranging over a string yields the byte offset of each rune start; keep
	// the last such offset that does not exceed the byte budget.
	cut := 0
	for i := range output {
		if i > keep {
			break
		}
		cut = i
	}
	return output[:cut] + "..."
}
+ +package main + +import ( + "context" + "encoding/json" + "fmt" + "math/rand" + "os" + "time" + + tea "github.com/charmbracelet/bubbletea" + "github.com/urfave/cli/v3" + + "github.com/livekit/livekit-cli/v2/pkg/agentfs" + "github.com/livekit/livekit-cli/v2/pkg/config" + "github.com/livekit/protocol/livekit" + lksdk "github.com/livekit/server-sdk-go/v2" +) + +var ( + simulateProjectConfig *config.ProjectConfig +) + +var simulateCommand = &cli.Command{ + Name: "simulate", + Usage: "Run agent simulations against LiveKit Cloud", + Before: func(ctx context.Context, cmd *cli.Command) (context.Context, error) { + pc, err := loadProjectDetails(cmd) + if err != nil { + return nil, err + } + simulateProjectConfig = pc + return nil, nil + }, + Action: runSimulate, + Flags: []cli.Flag{ + &cli.IntFlag{ + Name: "num-simulations", + Aliases: []string{"n"}, + Usage: "Number of scenarios to generate", + Value: 5, + }, + &cli.StringFlag{ + Name: "description", + Usage: "Agent description for scenario generation", + }, + &cli.StringFlag{ + Name: "scenario-group-id", + Usage: "Use a pre-configured scenario group", + }, + &cli.StringFlag{ + Name: "config", + Usage: "Path to simulation config `FILE` (default: simulation.json)", + }, + &cli.StringFlag{ + Name: "entrypoint", + Usage: "Agent entrypoint `FILE` (default: agent.py)", + }, + &cli.StringFlag{ + Name: "cloud-url", + Value: cloudAPIServerURL, + Hidden: true, + DefaultText: cloudAPIServerURL, + }, + }, +} + +// simulationConfig represents the simulation.json config file. 
+type simulationConfig struct { + AgentDescription string `json:"agent_description"` + Scenarios []scenarioConfig `json:"scenarios"` +} + +type scenarioConfig struct { + Label string `json:"label"` + Instructions string `json:"instructions"` + AgentExpectations string `json:"agent_expectations"` + Metadata map[string]string `json:"metadata"` +} + +func loadSimulationConfig(path string) (*simulationConfig, error) { + if path == "" { + path = "simulation.json" + } + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) && path == "simulation.json" { + return &simulationConfig{}, nil + } + return nil, fmt.Errorf("failed to read config: %w", err) + } + var cfg simulationConfig + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("failed to parse config: %w", err) + } + return &cfg, nil +} + +func generateAgentName() string { + const chars = "abcdefghijklmnopqrstuvwxyz0123456789" + b := make([]byte, 8) + for i := range b { + b[i] = chars[rand.Intn(len(chars))] + } + return "simulation-" + string(b) +} + +func runSimulate(ctx context.Context, cmd *cli.Command) error { + pc := simulateProjectConfig + + // Load simulation config + configPath := cmd.String("config") + cfg, err := loadSimulationConfig(configPath) + if err != nil { + return err + } + + // Resolve description + description := cmd.String("description") + if description == "" { + description = cfg.AgentDescription + } + + numSimulations := int32(cmd.Int("num-simulations")) + scenarioGroupID := cmd.String("scenario-group-id") + agentName := generateAgentName() + + // Detect project type, walking up parent directories if needed + projectDir, projectType, err := agentfs.DetectProjectRoot(".") + if err != nil { + return err + } + if !projectType.IsPython() { + return fmt.Errorf("simulate currently only supports Python agents (detected: %s)", projectType) + } + + // Resolve entrypoint + entrypoint, err := findEntrypoint(projectDir, cmd.String("entrypoint"), projectType) + 
if err != nil { + return err + } + + // Launch agent subprocess + agent, err := startAgent(AgentStartConfig{ + Dir: projectDir, + Entrypoint: entrypoint, + ProjectType: projectType, + CLIArgs: []string{"dev", "--no-reload"}, + Env: []string{ + "LIVEKIT_AGENT_NAME=" + agentName, + "LIVEKIT_URL=" + pc.URL, + "LIVEKIT_API_KEY=" + pc.APIKey, + "LIVEKIT_API_SECRET=" + pc.APISecret, + }, + ReadySignal: "registered worker", + }) + if err != nil { + return err + } + defer agent.Kill() + + // Create API client + cloudURL := cmd.String("cloud-url") + simClient := lksdk.NewAgentSimulationClient(cloudURL, pc.APIKey, pc.APISecret) + + // Build the create request + req := &livekit.CreateSimulationRunRequest{ + AgentName: agentName, + AgentDescription: description, + } + if len(cfg.Scenarios) > 0 { + scenarios := make([]*livekit.CreateSimulationRunRequest_Scenario, 0, len(cfg.Scenarios)) + for _, sc := range cfg.Scenarios { + scenarios = append(scenarios, &livekit.CreateSimulationRunRequest_Scenario{ + Label: sc.Label, + Instructions: sc.Instructions, + AgentExpectations: sc.AgentExpectations, + Metadata: sc.Metadata, + }) + } + req.Source = &livekit.CreateSimulationRunRequest_Scenarios{ + Scenarios: &livekit.CreateSimulationRunRequest_ScenarioList{ + Scenarios: scenarios, + }, + } + } else if scenarioGroupID != "" { + req.Source = &livekit.CreateSimulationRunRequest_GroupId{ + GroupId: scenarioGroupID, + } + } else { + req.Source = &livekit.CreateSimulationRunRequest_NumSimulations{ + NumSimulations: numSimulations, + } + } + + // Wait for worker registration or subprocess exit + fmt.Println("Starting agent...") + select { + case <-agent.Ready(): + // Worker registered + case err := <-agent.Done(): + logs := agent.RecentLogs(20) + for _, l := range logs { + fmt.Fprintln(os.Stderr, l) + } + if err != nil { + return fmt.Errorf("agent exited before registering: %w", err) + } + return fmt.Errorf("agent exited before registering") + case <-time.After(60 * time.Second): + return 
fmt.Errorf("timed out waiting for agent to register") + case <-ctx.Done(): + return ctx.Err() + } + + // Create the simulation run + fmt.Println("Creating simulation run...") + resp, err := simClient.CreateSimulationRun(ctx, req) + if err != nil { + return fmt.Errorf("failed to create simulation run: %w", err) + } + runID := resp.SimulationRunId + + // Run the TUI + model := newSimulateModel(simClient, runID, numSimulations, agent) + p := tea.NewProgram(model, tea.WithAltScreen()) + if _, err := p.Run(); err != nil { + return fmt.Errorf("TUI error: %w", err) + } + + return nil +} diff --git a/cmd/lk/simulate_subprocess.go b/cmd/lk/simulate_subprocess.go new file mode 100644 index 00000000..fcf92bed --- /dev/null +++ b/cmd/lk/simulate_subprocess.go @@ -0,0 +1,239 @@ +// Copyright 2025 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "bufio" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "syscall" + "time" + + "github.com/livekit/livekit-cli/v2/pkg/agentfs" +) + +// AgentProcess manages a Python agent subprocess. +type AgentProcess struct { + cmd *exec.Cmd + readyCh chan struct{} + doneCh chan error + + // LogStream receives log lines in real-time. Nil if not needed. + LogStream chan string + + mu sync.Mutex + logLines []string + maxLogs int +} + +// findPythonBinary locates a Python binary for the given project type. 
+func findPythonBinary(dir string, projectType agentfs.ProjectType) (string, []string, error) { + if projectType == agentfs.ProjectTypePythonUV { + uvPath, err := exec.LookPath("uv") + if err == nil { + return uvPath, []string{"run", "python"}, nil + } + } + + // Check common venv locations + for _, venvDir := range []string{".venv", "venv"} { + candidate := filepath.Join(dir, venvDir, "bin", "python") + if _, err := os.Stat(candidate); err == nil { + return candidate, nil, nil + } + } + + // Fall back to system python + pythonPath, err := exec.LookPath("python3") + if err != nil { + pythonPath, err = exec.LookPath("python") + if err != nil { + return "", nil, fmt.Errorf("could not find Python binary; ensure a virtual environment exists or Python is on PATH") + } + } + return pythonPath, nil, nil +} + +// findEntrypoint resolves the agent entrypoint file. +func findEntrypoint(dir, explicit string, projectType agentfs.ProjectType) (string, error) { + if explicit != "" { + path := explicit + if !filepath.IsAbs(path) { + path = filepath.Join(dir, path) + } + if _, err := os.Stat(path); err != nil { + return "", fmt.Errorf("entrypoint not found: %s", explicit) + } + return explicit, nil + } + def := projectType.DefaultEntrypoint() + if def == "" { + def = "agent.py" + } + + // Check project root first + if _, err := os.Stat(filepath.Join(dir, def)); err == nil { + return def, nil + } + + // Fall back to cwd-relative path (e.g. running from examples/drive-thru/) + cwd, _ := os.Getwd() + if rel, err := filepath.Rel(dir, cwd); err == nil && rel != "." { + candidate := filepath.Join(rel, def) + if _, err := os.Stat(filepath.Join(dir, candidate)); err == nil { + return candidate, nil + } + } + + return "", fmt.Errorf("entrypoint not found: %s (use --entrypoint to specify)", def) +} + +// AgentStartConfig configures how to launch an agent subprocess. +type AgentStartConfig struct { + Dir string + Entrypoint string + ProjectType agentfs.ProjectType + CLIArgs []string // e.g. 
["dev", "--no-reload"] or ["console", "--connect-addr", addr] + Env []string // e.g. ["LIVEKIT_AGENT_NAME=x"] or nil + ReadySignal string // substring to scan for in output (e.g. "registered worker"), empty to skip +} + +// startAgent launches a Python agent subprocess and monitors its output. +func startAgent(cfg AgentStartConfig) (*AgentProcess, error) { + pythonBin, prefixArgs, err := findPythonBinary(cfg.Dir, cfg.ProjectType) + if err != nil { + return nil, err + } + + args := append(prefixArgs, cfg.Entrypoint) + args = append(args, cfg.CLIArgs...) + cmd := exec.Command(pythonBin, args...) + cmd.Dir = cfg.Dir + if len(cfg.Env) > 0 { + cmd.Env = append(os.Environ(), cfg.Env...) + } + + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("failed to create stdout pipe: %w", err) + } + stderr, err := cmd.StderrPipe() + if err != nil { + return nil, fmt.Errorf("failed to create stderr pipe: %w", err) + } + + ap := &AgentProcess{ + cmd: cmd, + readyCh: make(chan struct{}), + doneCh: make(chan error, 1), + maxLogs: 200, + } + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("failed to start agent: %w", err) + } + + // Capture output from both stdout and stderr + readyOnce := sync.Once{} + scanOutput := func(r io.Reader) { + scanner := bufio.NewScanner(r) + scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) + for scanner.Scan() { + line := scanner.Text() + ap.appendLog(line) + if cfg.ReadySignal != "" && strings.Contains(line, cfg.ReadySignal) { + readyOnce.Do(func() { close(ap.readyCh) }) + } + } + } + + // If no ready signal, mark ready immediately + if cfg.ReadySignal == "" { + close(ap.readyCh) + } + + go scanOutput(stdout) + go scanOutput(stderr) + go func() { + ap.doneCh <- cmd.Wait() + }() + + return ap, nil +} + +func (ap *AgentProcess) appendLog(line string) { + ap.mu.Lock() + defer ap.mu.Unlock() + ap.logLines = append(ap.logLines, line) + if len(ap.logLines) > ap.maxLogs { + ap.logLines = 
ap.logLines[len(ap.logLines)-ap.maxLogs:] + } + // Stream to TUI if channel exists + if ap.LogStream != nil { + select { + case ap.LogStream <- line: + default: + } + } +} + +// Ready returns a channel that is closed when the agent worker has registered. +func (ap *AgentProcess) Ready() <-chan struct{} { + return ap.readyCh +} + +// Done returns a channel that receives the process exit error. +func (ap *AgentProcess) Done() <-chan error { + return ap.doneCh +} + +// RecentLogs returns the last n log lines from the subprocess. +func (ap *AgentProcess) RecentLogs(n int) []string { + ap.mu.Lock() + defer ap.mu.Unlock() + if n >= len(ap.logLines) { + result := make([]string, len(ap.logLines)) + copy(result, ap.logLines) + return result + } + result := make([]string, n) + copy(result, ap.logLines[len(ap.logLines)-n:]) + return result +} + +// LogCount returns the total number of log lines captured. +func (ap *AgentProcess) LogCount() int { + ap.mu.Lock() + defer ap.mu.Unlock() + return len(ap.logLines) +} + +// Kill sends SIGINT to the subprocess and SIGKILL after a timeout. +func (ap *AgentProcess) Kill() { + if ap.cmd.Process == nil { + return + } + _ = ap.cmd.Process.Signal(syscall.SIGINT) + select { + case <-ap.doneCh: + case <-time.After(5 * time.Second): + _ = ap.cmd.Process.Kill() + } +} diff --git a/cmd/lk/simulate_tui.go b/cmd/lk/simulate_tui.go new file mode 100644 index 00000000..1044f852 --- /dev/null +++ b/cmd/lk/simulate_tui.go @@ -0,0 +1,625 @@ +// Copyright 2025 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "fmt" + "strings" + "time" + + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + + "github.com/livekit/protocol/livekit" + lksdk "github.com/livekit/server-sdk-go/v2" +) + +// --- Styles --- + +var ( + tagStyle = lipgloss.NewStyle().Background(lipgloss.Color("#1fd5f9")).Foreground(lipgloss.Color("#000000")).Bold(true).Padding(0, 1) + greenStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("2")) + redStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("1")) + yellowStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("3")) + dimStyle = lipgloss.NewStyle().Faint(true) + boldStyle = lipgloss.NewStyle().Bold(true) + reverseStyle = lipgloss.NewStyle().Reverse(true) + cyanStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("6")).Bold(true) +) + +// --- Message types --- + +type simulationRunMsg struct { + run *livekit.SimulationRun + err error +} + +type pollTickMsg struct{} + +type subprocessExitMsg struct { + err error +} + +// --- Filter --- + +const ( + filterAll = iota + filterFailed + filterPassed + filterRunning +) + +var filterNames = []string{"All", "Failed", "Passed", "Running"} + +// --- Model --- + +type simulateModel struct { + client *lksdk.AgentSimulationClient + runID string + numSimulations int32 + agent *AgentProcess + + run *livekit.SimulationRun + runFinished bool + startTime time.Time + + filter int + cursor int + scrollOff int + detailJobID string + showLogs bool + + width int + height int + err error +} + +func newSimulateModel(client *lksdk.AgentSimulationClient, runID string, numSimulations int32, agent *AgentProcess) *simulateModel { + return &simulateModel{ + client: client, + runID: runID, + numSimulations: numSimulations, + agent: agent, + width: 80, + height: 24, + } +} + +func (m *simulateModel) Init() tea.Cmd { + return tea.Batch( + m.pollSimulation(), + 
m.waitSubprocess(), + tickCmd(), + ) +} + +func tickCmd() tea.Cmd { + return tea.Tick(time.Second, func(t time.Time) tea.Msg { + return pollTickMsg{} + }) +} + +func (m *simulateModel) pollSimulation() tea.Cmd { + return func() tea.Msg { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + resp, err := m.client.GetSimulationRun(ctx, &livekit.GetSimulationRunRequest{ + SimulationRunId: m.runID, + }) + if err != nil { + return simulationRunMsg{err: err} + } + return simulationRunMsg{run: resp.Run} + } +} + +func (m *simulateModel) waitSubprocess() tea.Cmd { + return func() tea.Msg { + err := <-m.agent.Done() + return subprocessExitMsg{err: err} + } +} + +func (m *simulateModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + m.width = msg.Width + m.height = msg.Height + + case simulationRunMsg: + if msg.err == nil && msg.run != nil { + m.run = msg.run + if m.startTime.IsZero() && msg.run.Status == livekit.SimulationRun_STATUS_RUNNING { + m.startTime = time.Now() + } + if msg.run.Status == livekit.SimulationRun_STATUS_COMPLETED || + msg.run.Status == livekit.SimulationRun_STATUS_FAILED { + m.runFinished = true + } + } + + case pollTickMsg: + var cmds []tea.Cmd + if !m.runFinished { + cmds = append(cmds, m.pollSimulation()) + } + cmds = append(cmds, tickCmd()) + return m, tea.Batch(cmds...) 
+ + case subprocessExitMsg: + // Subprocess exited — don't quit TUI, just note it + + case tea.KeyMsg: + return m.handleKey(msg) + } + return m, nil +} + +func (m *simulateModel) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { + switch msg.String() { + case "ctrl+c": + return m, tea.Quit + case "ctrl+l": + m.showLogs = !m.showLogs + case "up", "shift+tab": + m.cursor-- + case "down", "tab": + m.cursor++ + case "pgup": + m.cursor -= 20 + case "pgdown": + m.cursor += 20 + case "left": + m.filter = (m.filter + len(filterNames) - 1) % len(filterNames) + m.cursor = 0 + m.scrollOff = 0 + case "right": + m.filter = (m.filter + 1) % len(filterNames) + m.cursor = 0 + m.scrollOff = 0 + case "enter": + if m.detailJobID == "" { + jobs := m.filteredJobs() + if m.cursor >= 0 && m.cursor < len(jobs) { + m.detailJobID = jobs[m.cursor].job.Id + } + } + case "esc", "backspace": + if m.detailJobID != "" { + m.detailJobID = "" + } + case "q": + if m.detailJobID != "" { + m.detailJobID = "" + } else { + return m, tea.Quit + } + } + return m, nil +} + +type indexedJob struct { + origIdx int + job *livekit.SimulationRun_Job +} + +func (m *simulateModel) filteredJobs() []indexedJob { + if m.run == nil { + return nil + } + var result []indexedJob + for i, j := range m.run.Jobs { + match := false + switch m.filter { + case filterAll: + match = true + case filterFailed: + match = j.Status == livekit.SimulationRun_Job_STATUS_FAILED + case filterPassed: + match = j.Status == livekit.SimulationRun_Job_STATUS_COMPLETED + case filterRunning: + match = j.Status == livekit.SimulationRun_Job_STATUS_RUNNING + } + if match { + result = append(result, indexedJob{origIdx: i + 1, job: j}) + } + } + return result +} + +func (m *simulateModel) View() string { + if m.run == nil { + return m.viewWaiting() + } + switch m.run.Status { + case livekit.SimulationRun_STATUS_GENERATING: + return m.viewGenerating() + default: + return m.viewRunning() + } +} + +func (m *simulateModel) viewWaiting() string { + var b 
strings.Builder + b.WriteString("\n") + b.WriteString(tagStyle.Render("Simulate")) + b.WriteString(" ") + b.WriteString(cyanStyle.Render(m.runID)) + b.WriteString("\n\n") + b.WriteString(" [1/3] Starting...\n") + if m.showLogs { + b.WriteString(m.renderLogs()) + } + b.WriteString(dimStyle.Render(" Ctrl+L logs")) + b.WriteString("\n") + return b.String() +} + +func (m *simulateModel) viewGenerating() string { + var b strings.Builder + b.WriteString("\n") + b.WriteString(tagStyle.Render("Simulate")) + b.WriteString(" ") + b.WriteString(cyanStyle.Render(m.runID)) + b.WriteString("\n\n") + b.WriteString(fmt.Sprintf(" [1/3] Generating %d scenarios...\n", m.numSimulations)) + if m.showLogs { + b.WriteString(m.renderLogs()) + } + b.WriteString(dimStyle.Render(" Ctrl+L logs")) + b.WriteString("\n") + return b.String() +} + +func (m *simulateModel) viewRunning() string { + var b strings.Builder + + b.WriteString("\n") + b.WriteString(tagStyle.Render("Simulate")) + b.WriteString(" ") + b.WriteString(cyanStyle.Render(m.runID)) + b.WriteString("\n\n") + + // Header line + b.WriteString(m.renderHeader()) + b.WriteString("\n") + + // Progress counts + b.WriteString(m.renderCounts()) + b.WriteString("\n") + + // Filter tabs + b.WriteString(m.renderFilterTabs()) + b.WriteString("\n\n") + + if m.detailJobID != "" { + b.WriteString(m.renderDetail()) + } else { + b.WriteString(m.renderJobList()) + } + + b.WriteString("\n") + if m.showLogs { + b.WriteString(m.renderLogs()) + } + b.WriteString(m.renderHint()) + b.WriteString("\n") + return b.String() +} + +func (m *simulateModel) renderHeader() string { + var step, label, style string + switch { + case m.run.Status == livekit.SimulationRun_STATUS_COMPLETED || m.run.Status == livekit.SimulationRun_STATUS_FAILED: + step = "[3/3]" + _, _, failed, _ := m.jobCounts() + if m.run.Status == livekit.SimulationRun_STATUS_FAILED { + label = "Failed" + style = "red" + } else if failed > 0 { + label = "Completed with failures" + style = "yellow" + 
} else { + label = "Completed" + style = "green" + } + default: + step = "[2/3]" + label = "Running" + style = "yellow" + } + + header := dimStyle.Render(step) + " " + boldStyle.Render("Simulation") + " — " + switch style { + case "green": + header += greenStyle.Bold(true).Render(label) + case "red": + header += redStyle.Bold(true).Render(label) + case "yellow": + header += yellowStyle.Bold(true).Render(label) + } + return " " + header +} + +func (m *simulateModel) jobCounts() (total, done, passed, failed int) { + if m.run == nil { + return + } + total = len(m.run.Jobs) + for _, j := range m.run.Jobs { + switch j.Status { + case livekit.SimulationRun_Job_STATUS_COMPLETED: + done++ + passed++ + case livekit.SimulationRun_Job_STATUS_FAILED: + done++ + failed++ + } + } + return +} + +func (m *simulateModel) renderCounts() string { + total, done, passed, failed := m.jobCounts() + running := 0 + if m.run != nil { + for _, j := range m.run.Jobs { + if j.Status == livekit.SimulationRun_Job_STATUS_RUNNING { + running++ + } + } + } + + var parts []string + parts = append(parts, boldStyle.Render(fmt.Sprintf("%d/%d", done, total))) + if passed > 0 { + parts = append(parts, greenStyle.Render(fmt.Sprintf("%d passed", passed))) + } + if failed > 0 { + parts = append(parts, redStyle.Render(fmt.Sprintf("%d failed", failed))) + } + if running > 0 { + parts = append(parts, yellowStyle.Render(fmt.Sprintf("%d running", running))) + } + + elapsed := "" + if !m.startTime.IsZero() { + d := time.Since(m.startTime) + secs := int(d.Seconds()) + mins := secs / 60 + secs = secs % 60 + if mins > 0 { + elapsed = fmt.Sprintf("%dm%02ds", mins, secs) + } else { + elapsed = fmt.Sprintf("%ds", secs) + } + } + + result := " " + strings.Join(parts, " ") + if elapsed != "" { + result += " " + dimStyle.Render(elapsed) + } + return result +} + +func (m *simulateModel) renderFilterTabs() string { + total, _, passed, failed := m.jobCounts() + running := 0 + if m.run != nil { + for _, j := range m.run.Jobs 
{ + if j.Status == livekit.SimulationRun_Job_STATUS_RUNNING { + running++ + } + } + } + + counts := []int{total, failed, passed, running} + styles := []lipgloss.Style{lipgloss.NewStyle(), redStyle, greenStyle, yellowStyle} + + var parts []string + for i, name := range filterNames { + label := fmt.Sprintf("%s: %d", name, counts[i]) + if i == m.filter { + parts = append(parts, styles[i].Bold(true).Render(label)) + } else { + parts = append(parts, dimStyle.Render(label)) + } + } + return " " + strings.Join(parts, " ") +} + +func (m *simulateModel) renderJobList() string { + jobs := m.filteredJobs() + if len(jobs) == 0 { + return dimStyle.Render(" (no jobs match this filter)") + } + + // Clamp cursor + if m.cursor < 0 { + m.cursor = 0 + } + if m.cursor >= len(jobs) { + m.cursor = len(jobs) - 1 + } + + // Compute visible window + availHeight := m.height - 14 + if availHeight < 5 { + availHeight = 5 + } + + if m.cursor < m.scrollOff { + m.scrollOff = m.cursor + } else if m.cursor >= m.scrollOff+availHeight { + m.scrollOff = m.cursor - availHeight + 1 + } + if m.scrollOff < 0 { + m.scrollOff = 0 + } + if m.scrollOff > len(jobs)-availHeight { + m.scrollOff = len(jobs) - availHeight + } + if m.scrollOff < 0 { + m.scrollOff = 0 + } + + winStart := m.scrollOff + winEnd := m.scrollOff + availHeight + if winEnd > len(jobs) { + winEnd = len(jobs) + } + + var b strings.Builder + + if winStart > 0 { + b.WriteString(dimStyle.Render(fmt.Sprintf(" ... %d more above ...", winStart))) + b.WriteString("\n") + } + + for i := winStart; i < winEnd; i++ { + ij := jobs[i] + icon := jobIcon(ij.job) + instr := ij.job.Instructions + if len(instr) > 60 { + instr = instr[:60] + "..." + } + if instr == "" { + instr = "—" + } + + line := fmt.Sprintf(" %s %3d. 
%s %s", icon, ij.origIdx, dimStyle.Render(ij.job.Id), instr) + + if i == m.cursor { + line = reverseStyle.Render(line) + } + b.WriteString(line) + b.WriteString("\n") + } + + remaining := len(jobs) - winEnd + if remaining > 0 { + b.WriteString(dimStyle.Render(fmt.Sprintf(" ... %d more below ...", remaining))) + b.WriteString("\n") + } + + return b.String() +} + +func (m *simulateModel) renderDetail() string { + if m.run == nil { + return "" + } + var job *livekit.SimulationRun_Job + origIdx := 0 + for i, j := range m.run.Jobs { + if j.Id == m.detailJobID { + job = j + origIdx = i + 1 + break + } + } + if job == nil { + m.detailJobID = "" + return dimStyle.Render(" (job not found)\n") + } + + var b strings.Builder + b.WriteString("\n") + b.WriteString(fmt.Sprintf(" %s %s %s\n", + jobIcon(job), + boldStyle.Render(fmt.Sprintf("Job %d", origIdx)), + dimStyle.Render(job.Id), + )) + b.WriteString("\n") + + b.WriteString(boldStyle.Render(" Instructions:")) + b.WriteString("\n") + instr := job.Instructions + if instr == "" { + instr = "—" + } + for _, line := range strings.Split(instr, "\n") { + b.WriteString(" " + line + "\n") + } + b.WriteString("\n") + + b.WriteString(dimStyle.Bold(true).Render(" Expected:")) + b.WriteString("\n") + expect := job.AgentExpectations + if expect == "" { + expect = "—" + } + for _, line := range strings.Split(expect, "\n") { + b.WriteString(dimStyle.Render(" "+line) + "\n") + } + + if job.Error != "" { + b.WriteString("\n") + if job.Status == livekit.SimulationRun_Job_STATUS_COMPLETED { + b.WriteString(greenStyle.Bold(true).Render(" Result:")) + b.WriteString("\n") + for _, line := range strings.Split(job.Error, "\n") { + b.WriteString(greenStyle.Render(" "+line) + "\n") + } + } else { + b.WriteString(redStyle.Bold(true).Render(" Error:")) + b.WriteString("\n") + for _, line := range strings.Split(job.Error, "\n") { + b.WriteString(redStyle.Render(" "+line) + "\n") + } + } + } + return b.String() +} + +func (m *simulateModel) renderLogs() 
string { + var b strings.Builder + b.WriteString(dimStyle.Render(" " + strings.Repeat("─", 40))) + b.WriteString("\n") + logBudget := m.height - 15 + if logBudget < 3 { + logBudget = 3 + } + lines := m.agent.RecentLogs(logBudget) + for _, line := range lines { + b.WriteString(dimStyle.Render(" "+line) + "\n") + } + return b.String() +} + +func (m *simulateModel) renderHint() string { + if m.detailJobID != "" { + return dimStyle.Render(" ESC/q back · Ctrl+L logs") + } + hint := " ↑↓/Tab navigate · ENTER detail · ←→ filter · Ctrl+L logs" + if m.runFinished { + hint += " · q quit" + } + return dimStyle.Render(hint) +} + +func jobIcon(job *livekit.SimulationRun_Job) string { + switch job.Status { + case livekit.SimulationRun_Job_STATUS_COMPLETED: + return greenStyle.Render("✓") + case livekit.SimulationRun_Job_STATUS_FAILED: + return redStyle.Render("✗") + case livekit.SimulationRun_Job_STATUS_RUNNING: + return yellowStyle.Render("●") + default: + return dimStyle.Render("○") + } +} diff --git a/go.mod b/go.mod index d54940ce..06ca4ac0 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,8 @@ go 1.25.0 require ( github.com/BurntSushi/toml v1.5.0 github.com/Masterminds/semver/v3 v3.4.0 + github.com/charmbracelet/bubbles v1.0.0 + github.com/charmbracelet/bubbletea v1.3.10 github.com/charmbracelet/huh v0.7.1-0.20250818142555-c41a69ba6443 github.com/charmbracelet/huh/spinner v0.0.0-20250818142555-c41a69ba6443 github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 @@ -51,13 +53,14 @@ require ( github.com/catppuccin/go v0.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/chainguard-dev/git-urls v1.0.2 // indirect - github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7 // indirect - github.com/charmbracelet/bubbletea v1.3.6 // indirect - github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect - github.com/charmbracelet/x/ansi v0.9.3 // indirect - github.com/charmbracelet/x/cellbuf v0.0.13 // 
indirect + github.com/charmbracelet/colorprofile v0.4.1 // indirect + github.com/charmbracelet/x/ansi v0.11.6 // indirect + github.com/charmbracelet/x/cellbuf v0.0.15 // indirect github.com/charmbracelet/x/exp/strings v0.0.0-20240809174237-9ab0ca04ce0c // indirect - github.com/charmbracelet/x/term v0.2.1 // indirect + github.com/charmbracelet/x/term v0.2.2 // indirect + github.com/clipperhouse/displaywidth v0.9.0 // indirect + github.com/clipperhouse/stringish v0.1.1 // indirect + github.com/clipperhouse/uax29/v2 v2.5.0 // indirect github.com/cloudflare/circl v1.6.1 // indirect github.com/containerd/console v1.0.5 // indirect github.com/containerd/containerd/api v1.10.0 // indirect @@ -114,12 +117,12 @@ require ( github.com/livekit/mageutil v0.0.0-20250511045019-0f1ff63f7731 // indirect github.com/livekit/mediatransportutil v0.0.0-20251128105421-19c7a7b81c22 // indirect github.com/livekit/psrpc v0.7.1 // indirect - github.com/lucasb-eyer/go-colorful v1.2.0 // indirect + github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/magefile/mage v1.15.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-localereader v0.0.1 // indirect - github.com/mattn/go-runewidth v0.0.16 // indirect + github.com/mattn/go-runewidth v0.0.19 // indirect github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect github.com/moby/buildkit v0.26.2 // indirect github.com/moby/locker v1.0.1 // indirect @@ -202,3 +205,7 @@ require ( mvdan.cc/sh/v3 v3.12.0 // indirect sigs.k8s.io/json v0.0.0-20241014173422-cfa47c3a1cc8 // indirect ) + +replace github.com/livekit/protocol => ../protocol + +replace github.com/livekit/server-sdk-go/v2 => ../server-sdk-go diff --git a/go.sum b/go.sum index 96fe90a1..0642ed61 100644 --- a/go.sum +++ b/go.sum @@ -71,22 +71,22 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF github.com/cespare/xxhash/v2 v2.3.0/go.mod 
h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chainguard-dev/git-urls v1.0.2 h1:pSpT7ifrpc5X55n4aTTm7FFUE+ZQHKiqpiwNkJrVcKQ= github.com/chainguard-dev/git-urls v1.0.2/go.mod h1:rbGgj10OS7UgZlbzdUQIQpT0k/D4+An04HJY7Ol+Y/o= -github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7 h1:JFgG/xnwFfbezlUnFMJy0nusZvytYysV4SCS2cYbvws= -github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7/go.mod h1:ISC1gtLcVilLOf23wvTfoQuYbW2q0JevFxPfUzZ9Ybw= -github.com/charmbracelet/bubbletea v1.3.6 h1:VkHIxPJQeDt0aFJIsVxw8BQdh/F/L2KKZGsK6et5taU= -github.com/charmbracelet/bubbletea v1.3.6/go.mod h1:oQD9VCRQFF8KplacJLo28/jofOI2ToOfGYeFgBBxHOc= -github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs= -github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk= +github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc= +github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E= +github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw= +github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4= +github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk= +github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk= github.com/charmbracelet/huh v0.7.1-0.20250818142555-c41a69ba6443 h1:MrWAnG+pG1mmLMPbTkbv1h/epSDuKF9WHxdiAjO5aXY= github.com/charmbracelet/huh v0.7.1-0.20250818142555-c41a69ba6443/go.mod h1:5YVc+SlZ1IhQALxRPpkGwwEKftN/+OlJlnJYlDRFqN4= github.com/charmbracelet/huh/spinner v0.0.0-20250818142555-c41a69ba6443 h1:8ZNkdakTlK/oo0L6aIBdu72JpDK5LJVwYAYZ3yKEtRs= github.com/charmbracelet/huh/spinner v0.0.0-20250818142555-c41a69ba6443/go.mod h1:imftm8y+Db+rZ4Jcb6A7qJ0eOX78s9m84n8cdipC+R0= 
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 h1:ZR7e0ro+SZZiIZD7msJyA+NjkCNNavuiPBLgerbOziE= github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834/go.mod h1:aKC/t2arECF6rNOnaKaVU6y4t4ZeHQzqfxedE/VkVhA= -github.com/charmbracelet/x/ansi v0.9.3 h1:BXt5DHS/MKF+LjuK4huWrC6NCvHtexww7dMayh6GXd0= -github.com/charmbracelet/x/ansi v0.9.3/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE= -github.com/charmbracelet/x/cellbuf v0.0.13 h1:/KBBKHuVRbq1lYx5BzEHBAFBP8VcQzJejZ/IA3iR28k= -github.com/charmbracelet/x/cellbuf v0.0.13/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs= +github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8= +github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ= +github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI= +github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q= github.com/charmbracelet/x/conpty v0.1.0 h1:4zc8KaIcbiL4mghEON8D72agYtSeIgq8FSThSPQIb+U= github.com/charmbracelet/x/conpty v0.1.0/go.mod h1:rMFsDJoDwVmiYM10aD4bH2XiRgwI7NYJtQgl5yskjEQ= github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 h1:JSt3B+U9iqk37QUU2Rvb6DSBYRLtWqFqfxf8l5hOZUA= @@ -95,12 +95,18 @@ github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91 h1:payR github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U= github.com/charmbracelet/x/exp/strings v0.0.0-20240809174237-9ab0ca04ce0c h1:6IUwt5Pfsv8ugjei3FKMzKIEJMZJt5QN8nU2VN/IOPw= github.com/charmbracelet/x/exp/strings v0.0.0-20240809174237-9ab0ca04ce0c/go.mod h1:pBhA0ybfXv6hDjQUZ7hk1lVxBiUbupdw5R31yPUViVQ= -github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= -github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= +github.com/charmbracelet/x/term v0.2.2 
h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk= +github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI= github.com/charmbracelet/x/termios v0.1.1 h1:o3Q2bT8eqzGnGPOYheoYS8eEleT5ZVNYNy8JawjaNZY= github.com/charmbracelet/x/termios v0.1.1/go.mod h1:rB7fnv1TgOPOyyKRJ9o+AsTU/vK5WHJ2ivHeut/Pcwo= github.com/charmbracelet/x/xpty v0.1.2 h1:Pqmu4TEJ8KeA9uSkISKMU3f+C1F6OGBn8ABuGlqCbtI= github.com/charmbracelet/x/xpty v0.1.2/go.mod h1:XK2Z0id5rtLWcpeNiMYBccNNBrP2IJnzHI0Lq13Xzq4= +github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA= +github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA= +github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= +github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= +github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U= +github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0= github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs= github.com/codahale/rfc6979 v0.0.0-20141003034818-6a90f24967eb h1:EDmT6Q9Zs+SbUoc7Ik9EfrFqcylYqgPZ9ANSbTAntnE= @@ -269,14 +275,10 @@ github.com/livekit/mageutil v0.0.0-20250511045019-0f1ff63f7731 h1:9x+U2HGLrSw5AT github.com/livekit/mageutil v0.0.0-20250511045019-0f1ff63f7731/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ= github.com/livekit/mediatransportutil v0.0.0-20251128105421-19c7a7b81c22 h1:dzCBxOGLLWVtQhL7OYK2EGN+5Q+23Mq/jfz4vQisirA= github.com/livekit/mediatransportutil v0.0.0-20251128105421-19c7a7b81c22/go.mod h1:mSNtYzSf6iY9xM3UX42VEI+STHvMgHmrYzEHPcdhB8A= -github.com/livekit/protocol v1.44.1-0.20260223200831-a71190b6850a h1:NiNcePPTmHoGvAfp1b6EYbgLBP+WkQoPzJFul3skPXo= -github.com/livekit/protocol 
v1.44.1-0.20260223200831-a71190b6850a/go.mod h1:BLJHYHErQTu3+fnmfGrzN6CbHxNYiooFIIYGYxXxotw= github.com/livekit/psrpc v0.7.1 h1:ms37az0QTD3UXIWuUC5D/SkmKOlRMVRsI261eBWu/Vw= github.com/livekit/psrpc v0.7.1/go.mod h1:bZ4iHFQptTkbPnB0LasvRNu/OBYXEu1NA6O5BMFo9kk= -github.com/livekit/server-sdk-go/v2 v2.13.4-0.20260223172816-77b4264bca63 h1:icaw405HodmKB7tkhDwgAxgLAM5VA2o7pYQYC9677B0= -github.com/livekit/server-sdk-go/v2 v2.13.4-0.20260223172816-77b4264bca63/go.mod h1:X1vsxK4ZN4ds1BxuJkFrHNHRnkyu/kMICSlSfpuzvjo= -github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= -github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= +github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag= +github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/magefile/mage v1.15.0 h1:BvGheCMAsG3bWUDbZ8AyXXpCNwU9u5CB6sM+HNb9HYg= github.com/magefile/mage v1.15.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= @@ -285,8 +287,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= -github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= -github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= +github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4= 
github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE= github.com/moby/buildkit v0.26.2 h1:EIh5j0gzRsCZmQzvgNNWzSDbuKqwUIiBH7ssqLv8RU8= @@ -402,7 +404,6 @@ github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++ github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA= github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI= github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= -github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rodaine/protogofakeit v0.1.1 h1:ZKouljuRM3A+TArppfBqnH8tGZHOwM/pjvtXe9DaXH8= diff --git a/pkg/agentfs/detect.go b/pkg/agentfs/detect.go index 55a0ceae..4d8d1c05 100644 --- a/pkg/agentfs/detect.go +++ b/pkg/agentfs/detect.go @@ -16,7 +16,10 @@ package agentfs import ( "errors" + "fmt" "io/fs" + "os" + "path/filepath" "github.com/livekit/livekit-cli/v2/pkg/util" "github.com/pelletier/go-toml" @@ -115,3 +118,26 @@ func DetectProjectType(dir fs.FS) (ProjectType, error) { return ProjectTypeUnknown, errors.New("expected package.json, requirements.txt, pyproject.toml, or lock files") } + +// DetectProjectRoot walks up from dir to find a directory containing project +// files (pyproject.toml, requirements.txt, package.json, etc). Returns the +// absolute path to the project root and the detected project type. 
+func DetectProjectRoot(dir string) (string, ProjectType, error) { + absDir, err := filepath.Abs(dir) + if err != nil { + return "", ProjectTypeUnknown, err + } + + for { + pt, err := DetectProjectType(os.DirFS(absDir)) + if err == nil { + return absDir, pt, nil + } + + parent := filepath.Dir(absDir) + if parent == absDir { + return "", ProjectTypeUnknown, fmt.Errorf("could not detect project type in %s or any parent directory", dir) + } + absDir = parent + } +} diff --git a/pkg/apm/apm.go b/pkg/apm/apm.go index a125a1e4..b8f219d6 100644 --- a/pkg/apm/apm.go +++ b/pkg/apm/apm.go @@ -132,6 +132,41 @@ func (a *APM) StreamDelayMs() int { return int(C.apm_stream_delay_ms(a.handle)) } +// Stats holds AEC statistics from the WebRTC APM. +type Stats struct { + EchoReturnLoss float64 // ERL in dB (higher = more echo removed) + EchoReturnLossEnhancement float64 // ERLE in dB (higher = better cancellation) + DivergentFilterFraction float64 // 0-1, fraction of time filter is divergent + DelayMs int // Estimated echo path delay + ResidualEchoLikelihood float64 // 0-1, likelihood of residual echo + HasERL bool + HasERLE bool + HasDelay bool + HasResidualEcho bool + HasDivergent bool +} + +// GetStats returns the current AEC statistics. 
+func (a *APM) GetStats() Stats { + if a.handle == nil { + return Stats{} + } + var cs C.ApmStats + C.apm_get_stats(a.handle, &cs) + return Stats{ + EchoReturnLoss: float64(cs.echo_return_loss), + EchoReturnLossEnhancement: float64(cs.echo_return_loss_enhancement), + DivergentFilterFraction: float64(cs.divergent_filter_fraction), + DelayMs: int(cs.delay_ms), + ResidualEchoLikelihood: float64(cs.residual_echo_likelihood), + HasERL: cs.has_erl != 0, + HasERLE: cs.has_erle != 0, + HasDelay: cs.has_delay != 0, + HasResidualEcho: cs.has_residual_echo != 0, + HasDivergent: cs.has_divergent != 0, + } +} + func (a *APM) Close() { if a.handle != nil { C.apm_destroy(a.handle) diff --git a/pkg/apm/bridge.cpp b/pkg/apm/bridge.cpp index 41f6de3e..7f790109 100644 --- a/pkg/apm/bridge.cpp +++ b/pkg/apm/bridge.cpp @@ -73,4 +73,25 @@ int apm_stream_delay_ms(ApmHandle h) { return inst->apm->stream_delay_ms(); } +void apm_get_stats(ApmHandle h, ApmStats* out) { + if (!h || !out) return; + auto* inst = static_cast(h); + auto stats = inst->apm->GetStatistics(); + + out->has_erl = stats.echo_return_loss.has_value() ? 1 : 0; + out->echo_return_loss = stats.echo_return_loss.value_or(0.0); + + out->has_erle = stats.echo_return_loss_enhancement.has_value() ? 1 : 0; + out->echo_return_loss_enhancement = stats.echo_return_loss_enhancement.value_or(0.0); + + out->has_divergent = stats.divergent_filter_fraction.has_value() ? 1 : 0; + out->divergent_filter_fraction = stats.divergent_filter_fraction.value_or(0.0); + + out->has_delay = stats.delay_ms.has_value() ? 1 : 0; + out->delay_ms = stats.delay_ms.value_or(0); + + out->has_residual_echo = stats.residual_echo_likelihood.has_value() ? 
// fft transforms a in place using an iterative radix-2 Cooley-Tukey scheme
// (forward transform, negative exponent). Slices of length 0 or 1 are left
// untouched. The butterfly indexing assumes len(a) is a power of two; rfft
// guarantees that by padding.
func fft(a []complex128) {
	size := len(a)
	if size <= 1 {
		return
	}

	// Reorder the input into bit-reversed index order so the butterflies
	// below can work in place.
	rev := 0
	for idx := 1; idx < size; idx++ {
		mask := size >> 1
		for rev&mask != 0 {
			rev ^= mask
			mask >>= 1
		}
		rev ^= mask
		if idx < rev {
			a[idx], a[rev] = a[rev], a[idx]
		}
	}

	// Merge sub-transforms of span 2, 4, ..., size.
	for span := 2; span <= size; span <<= 1 {
		half := span / 2
		step := cmplx.Exp(complex(0, -2*math.Pi/float64(span)))
		for base := 0; base < size; base += span {
			twiddle := complex(1, 0)
			for k := 0; k < half; k++ {
				lo, hi := base+k, base+k+half
				even, odd := a[lo], twiddle*a[hi]
				a[lo] = even + odd
				a[hi] = even - odd
				twiddle *= step
			}
		}
	}
}

// rfft zero-pads x to the next power of two, runs fft on it and returns the
// first n/2+1 bins together with n, the padded transform length.
func rfft(x []float64) ([]complex128, int) {
	padded := nextPow2(len(x))
	work := make([]complex128, padded)
	for i := range x {
		work[i] = complex(x[i], 0)
	}
	fft(work)
	return work[:padded/2+1], padded
}

// nextPow2 returns the smallest power of two that is >= n, and 1 when n <= 1.
func nextPow2(n int) int {
	pow := 1
	for pow < n {
		pow *= 2
	}
	return pow
}
package console import ( "context" "encoding/binary" - "io" - "log" "math" "net" "sync" + "time" + + agent "github.com/livekit/protocol/livekit/agent" "github.com/livekit/livekit-cli/v2/pkg/apm" "github.com/livekit/livekit-cli/v2/pkg/portaudio" ) const ( - SampleRate = 48000 - Channels = 1 - FrameDurationMs = 20 - SamplesPerFrame = SampleRate * FrameDurationMs / 1000 // 960 - APMFrameSamples = SampleRate / 100 // 480 (10ms) - RingBufferFrames = 10 // ~200ms buffer - NumFFTBands = 14 + SampleRate = 48000 + Channels = 1 + FrameDurationMs = 30 + SamplesPerFrame = SampleRate * FrameDurationMs / 1000 // 1440 + APMFrameSamples = SampleRate / 100 // 480 (10ms) + NumFFTBands = 14 + + CaptureRingFrames = 50 // ~1.5s — small, just absorbs jitter between mic and speaker loops + PlaybackRingFrames = 4000 // ~120s — large, TTS pushes faster than real-time ) type AudioPipeline struct { @@ -34,14 +48,27 @@ type AudioPipeline struct { outputStream *portaudio.Stream apmInst *apm.APM conn net.Conn + connMu sync.Mutex // protects writes to conn captureRing *RingBuffer playbackRing *RingBuffer + // Events channel receives SessionEvents from the agent for the TUI. + Events chan *agent.SessionEvent + + // ready is closed when the agent session is established (first TCP message). + ready chan struct{} + readyOnce sync.Once + + // flushCancel cancels the current waitForDrainAndAck goroutine. + // Only accessed from the tcpReader goroutine. 
+ flushCancel context.CancelFunc + mu sync.Mutex fftBands [NumFFTBands]float64 muted bool level float64 // capture level in dB + playing bool // true when outputting real audio (not silence) cancel context.CancelFunc wg sync.WaitGroup @@ -73,7 +100,7 @@ func NewPipeline(cfg PipelineConfig) (*AudioPipeline, error) { apmCfg.RenderChannels = Channels apmInst, err = apm.NewAPM(apmCfg) if err != nil { - log.Printf("warning: failed to create APM, running without AEC: %v", err) + apmInst = nil // run without AEC } } @@ -84,33 +111,35 @@ func NewPipeline(cfg PipelineConfig) (*AudioPipeline, error) { apmInst.SetStreamDelayMs(delayMs) } - p := &AudioPipeline{ + return &AudioPipeline{ inputStream: inputStream, outputStream: outputStream, apmInst: apmInst, conn: cfg.Conn, - captureRing: NewRingBuffer(SamplesPerFrame * RingBufferFrames), - playbackRing: NewRingBuffer(SamplesPerFrame * RingBufferFrames), - } - return p, nil + captureRing: NewRingBuffer(SamplesPerFrame * CaptureRingFrames), + playbackRing: NewRingBuffer(SamplesPerFrame * PlaybackRingFrames), + Events: make(chan *agent.SessionEvent, 64), + ready: make(chan struct{}), + }, nil } func (p *AudioPipeline) Start(ctx context.Context) error { ctx, p.cancel = context.WithCancel(ctx) - if err := p.inputStream.Start(); err != nil { + // Start output before input so the render path is running when the + // first capture frame arrives. 
+ if err := p.outputStream.Start(); err != nil { return err } - if err := p.outputStream.Start(); err != nil { - p.inputStream.Stop() + if err := p.inputStream.Start(); err != nil { + p.outputStream.Stop() return err } - p.wg.Add(4) - go p.captureReader(ctx) - go p.captureWorker(ctx) - go p.playbackWorker(ctx) - go p.playbackWriter(ctx) + p.wg.Add(3) + go p.micLoop(ctx) + go p.speakerLoop(ctx) + go p.tcpReader(ctx) <-ctx.Done() return nil @@ -121,13 +150,10 @@ func (p *AudioPipeline) Stop() { p.cancel() } - // Stop PortAudio streams first (prevents callbacks into dead goroutines) p.inputStream.Stop() p.outputStream.Stop() - - // Wake up any blocked ring buffer readers + p.conn.Close() p.captureRing.cond.Broadcast() - p.playbackRing.cond.Broadcast() p.wg.Wait() @@ -137,8 +163,18 @@ func (p *AudioPipeline) Stop() { if p.apmInst != nil { p.apmInst.Close() } +} + +func (p *AudioPipeline) writeMessage(msg *agent.SessionMessage) error { + p.connMu.Lock() + defer p.connMu.Unlock() + return WriteSessionMessage(p.conn, msg) +} - WriteMessage(p.conn, MsgEOF, nil) +func (p *AudioPipeline) SendRequest(req *agent.SessionRequest) error { + return p.writeMessage(&agent.SessionMessage{ + Message: &agent.SessionMessage_Request{Request: req}, + }) } func (p *AudioPipeline) SetMuted(muted bool) { @@ -165,246 +201,346 @@ func (p *AudioPipeline) FFTBands() [NumFFTBands]float64 { return p.fftBands } -func (p *AudioPipeline) captureReader(ctx context.Context) { +func (p *AudioPipeline) IsPlaying() bool { + p.mu.Lock() + defer p.mu.Unlock() + return p.playing +} + +func (p *AudioPipeline) AECStats() *apm.Stats { + if p.apmInst == nil { + return nil + } + s := p.apmInst.GetStats() + return &s +} + +// micLoop reads mic input at hardware rate and writes to the capture ring. +// Muting is applied here so speakerLoop always sees clean data. 
+func (p *AudioPipeline) micLoop(ctx context.Context) { defer p.wg.Done() buf := make([]int16, SamplesPerFrame*Channels) + for { - select { - case <-ctx.Done(): + if ctx.Err() != nil { return - default: } if err := p.inputStream.Read(buf); err != nil { if ctx.Err() != nil { return } - log.Printf("capture read error: %v", err) - return + continue + } + + p.mu.Lock() + muted := p.muted + p.mu.Unlock() + + if muted { + for i := range buf { + buf[i] = 0 + } } + p.captureRing.Write(buf) } } -func (p *AudioPipeline) captureWorker(ctx context.Context) { +// speakerLoop runs all APM processing and output. Paced by outputStream.Write +// at the hardware output rate (~30ms). Each iteration: +// 1. Reads capture from captureRing (non-blocking, silence if empty) +// 2. Reads playback from playbackRing (non-blocking, silence if empty) +// 3. ProcessRender then ProcessCapture (single-threaded, 1:1) +// 4. Writes playback to speakers +// 5. Sends processed capture to agent +func (p *AudioPipeline) speakerLoop(ctx context.Context) { defer p.wg.Done() - frame := make([]int16, SamplesPerFrame*Channels) + captureBuf := make([]int16, SamplesPerFrame*Channels) + playbackBuf := make([]int16, SamplesPerFrame*Channels) apmBuf := make([]int16, APMFrameSamples*Channels) + ready := false for { - select { - case <-ctx.Done(): + if ctx.Err() != nil { return - default: } - if !p.captureRing.Read(frame) { - return + // Read capture (non-blocking); pad remainder with silence. + cn := p.captureRing.ReadAvailable(captureBuf) + for i := cn; i < len(captureBuf); i++ { + captureBuf[i] = 0 + } + + // Read playback (non-blocking); pad remainder with silence. 
+ pn := p.playbackRing.ReadAvailable(playbackBuf) + for i := pn; i < len(playbackBuf); i++ { + playbackBuf[i] = 0 } p.mu.Lock() - muted := p.muted + p.playing = pn > 0 p.mu.Unlock() - if muted { - for i := range frame { - frame[i] = 0 - } - } - - // Process through APM in 10ms chunks + // ProcessRender then ProcessCapture — both in this goroutine, + // right next to each other, no mutex needed. if p.apmInst != nil { for i := 0; i < SamplesPerFrame; i += APMFrameSamples { - copy(apmBuf, frame[i:i+APMFrameSamples]) - if err := p.apmInst.ProcessCapture(apmBuf); err != nil { - log.Printf("APM capture error: %v", err) - } - copy(frame[i:], apmBuf) + copy(apmBuf, playbackBuf[i:i+APMFrameSamples]) + _ = p.apmInst.ProcessRender(apmBuf) + + copy(apmBuf, captureBuf[i:i+APMFrameSamples]) + _ = p.apmInst.ProcessCapture(apmBuf) + copy(captureBuf[i:], apmBuf) } } - p.computeMetrics(frame) - - payload := SamplesToBytes(frame) - if err := WriteMessage(p.conn, MsgCapture, payload); err != nil { + // Write playback to speakers — blocks at hardware rate. + if err := p.outputStream.Write(playbackBuf); err != nil { if ctx.Err() != nil { return } - log.Printf("TCP send error: %v", err) - return } + + // Send processed capture to agent (only after session is ready). + if !ready { + select { + case <-p.ready: + ready = true + default: + continue + } + } + + p.computeMetrics(captureBuf) + + _ = p.writeMessage(&agent.SessionMessage{ + Message: &agent.SessionMessage_AudioInput{ + AudioInput: &agent.AudioFrame{ + Data: SamplesToBytes(captureBuf), + SampleRate: SampleRate, + NumChannels: Channels, + SamplesPerChannel: uint32(SamplesPerFrame), + }, + }, + }) } } -func (p *AudioPipeline) playbackWorker(ctx context.Context) { +// tcpReader reads messages from the agent over TCP and dispatches them. 
+func (p *AudioPipeline) tcpReader(ctx context.Context) { defer p.wg.Done() - apmBuf := make([]int16, APMFrameSamples*Channels) for { - select { - case <-ctx.Done(): + msg, err := ReadSessionMessage(p.conn) + if err != nil { return - default: } - msgType, payload, err := ReadMessage(p.conn) - if err != nil { - if ctx.Err() != nil { - return - } - if err == io.EOF { - log.Printf("Agent disconnected") - return + p.readyOnce.Do(func() { close(p.ready) }) + + switch m := msg.Message.(type) { + case *agent.SessionMessage_AudioOutput: + p.playbackRing.Write(BytesToSamples(m.AudioOutput.Data)) + + case *agent.SessionMessage_Event: + select { + case p.Events <- m.Event: + default: } - // Timeout is expected when no data is flowing - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - continue + + case *agent.SessionMessage_AudioPlaybackClear: + if p.flushCancel != nil { + p.flushCancel() + p.flushCancel = nil } - log.Printf("TCP recv error: %v", err) - return - } + p.playbackRing.Reset() - switch msgType { - case MsgRender: - samples := BytesToSamples(payload) - - if p.apmInst != nil { - for i := 0; i < len(samples); i += APMFrameSamples { - end := i + APMFrameSamples - if end > len(samples) { - end = len(samples) - } - chunk := samples[i:end] - if len(chunk) == APMFrameSamples { - copy(apmBuf, chunk) - if err := p.apmInst.ProcessRender(apmBuf); err != nil { - log.Printf("APM render error: %v", err) - } - copy(chunk, apmBuf) - } + case *agent.SessionMessage_AudioPlaybackFlush: + if p.flushCancel != nil { + p.flushCancel() + } + flushCtx, cancel := context.WithCancel(ctx) + p.flushCancel = cancel + go p.waitForDrainAndAck(flushCtx) + + case *agent.SessionMessage_Response: + if ev := responseToEvent(m.Response); ev != nil { + select { + case p.Events <- ev: + default: } } + } + } +} - p.playbackRing.Write(samples) - - case MsgEOF: - log.Printf("Agent sent EOF") - return - - case MsgConfig: - // TODO: handle config messages +func responseToEvent(resp 
*agent.SessionResponse) *agent.SessionEvent { + if resp == nil { + return nil + } + if r, ok := resp.Response.(*agent.SessionResponse_SendMessage); ok { + if r.SendMessage != nil && len(r.SendMessage.Items) > 0 { + return &agent.SessionEvent{ + Event: &agent.SessionEvent_ConversationItemAdded{ + ConversationItemAdded: &agent.ConversationItemAddedEvent{ + Item: r.SendMessage.Items[len(r.SendMessage.Items)-1], + }, + }, + } } } + return nil } -func (p *AudioPipeline) playbackWriter(ctx context.Context) { - defer p.wg.Done() - buf := make([]int16, SamplesPerFrame*Channels) - silence := make([]int16, SamplesPerFrame*Channels) +func (p *AudioPipeline) sendPlaybackFinished() { + _ = p.writeMessage(&agent.SessionMessage{ + Message: &agent.SessionMessage_AudioPlaybackFinished{ + AudioPlaybackFinished: &agent.AudioPlaybackFinished{}, + }, + }) +} - for { +func (p *AudioPipeline) waitForDrainAndAck(ctx context.Context) { + for p.playbackRing.Available() > 0 { select { case <-ctx.Done(): return default: } - - // If not enough data, play silence to avoid blocking - if p.playbackRing.Available() < SamplesPerFrame*Channels { - if err := p.outputStream.Write(silence); err != nil { - if ctx.Err() != nil { - return - } - log.Printf("playback write error: %v", err) - return - } - continue - } - - p.playbackRing.Read(buf) - if err := p.outputStream.Write(buf); err != nil { - if ctx.Err() != nil { - return - } - log.Printf("playback write error: %v", err) - return - } + time.Sleep(5 * time.Millisecond) } + select { + case <-ctx.Done(): + return + default: + } + p.sendPlaybackFinished() } func (p *AudioPipeline) computeMetrics(samples []int16) { - var sum float64 - for _, s := range samples { + n := len(samples) + sr := float64(SampleRate) + + // Convert to float64, normalize, apply Hanning window + x := make([]float64, n) + for i, s := range samples { v := float64(s) / 32768.0 - sum += v * v + w := 0.5 * (1 - math.Cos(2*math.Pi*float64(i)/float64(n))) + x[i] = v * w } - rms := 
math.Sqrt(sum / float64(len(samples))) - db := 20 * math.Log10(rms+1e-10) - // Simple band-energy estimation using overlapping windows - // Not a real FFT, but good enough for a visualizer - var bands [NumFFTBands]float64 - bandSize := len(samples) / NumFFTBands - if bandSize < 1 { - bandSize = 1 + // Real FFT + X, nfft := rfft(x) + + // Magnitude spectrum, scaled by 2/n + mag := make([]float64, len(X)) + scale := 2.0 / float64(n) + for i, c := range X { + r, im := real(c), imag(c) + mag[i] = math.Sqrt(r*r+im*im) * scale } - for b := 0; b < NumFFTBands; b++ { - start := b * bandSize - end := start + bandSize - if end > len(samples) { - end = len(samples) + mag[0] *= 0.5 + if n%2 == 0 { + mag[len(mag)-1] *= 0.5 + } + + // Geometric frequency band edges: 20 Hz → Nyquist*0.96 + nb := NumFFTBands + nyquist := sr * 0.5 * 0.96 + logLow := math.Log(20.0) + logHigh := math.Log(nyquist) + edges := make([]float64, nb+1) + for i := 0; i <= nb; i++ { + edges[i] = math.Exp(logLow + float64(i)*(logHigh-logLow)/float64(nb)) + } + + // Bin power into frequency bands + binFreq := sr / float64(nfft) + sump := make([]float64, nb) + cnts := make([]float64, nb) + for i, m := range mag { + freq := float64(i) * binFreq + // Find band via edges (equivalent to np.digitize - 1, clipped) + band := nb - 1 + for b := 1; b <= nb; b++ { + if freq < edges[b] { + band = b - 1 + break + } } - var bandSum float64 - for i := start; i < end; i++ { - v := float64(samples[i]) / 32768.0 - bandSum += v * v + if band < 0 { + band = 0 } - bandRMS := math.Sqrt(bandSum / float64(end-start)) - bands[b] = math.Min(bandRMS*8.0, 1.0) + sump[band] += m * m + cnts[band]++ } - p.mu.Lock() - p.level = db - p.fftBands = bands - p.mu.Unlock() -} - -func ComputeFFTBands(samples []int16) [NumFFTBands]float64 { + // Mean power → dB → normalize to [0,1] + const floorDB, hotDB = -70.0, -20.0 var bands [NumFFTBands]float64 - bandSize := len(samples) / NumFFTBands - if bandSize < 1 { - bandSize = 1 - } - for b := 0; b < 
NumFFTBands; b++ { - start := b * len(samples) / NumFFTBands - end := (b + 1) * len(samples) / NumFFTBands - if end > len(samples) { - end = len(samples) + for b := 0; b < nb; b++ { + c := cnts[b] + if c == 0 { + c = 1 } - var sum float64 - for i := start; i < end; i++ { - v := float64(samples[i]) / 32768.0 - sum += v * v + pmean := sump[b] / c + db := 10.0 * math.Log10(pmean + 1e-12) + lev := (db - floorDB) / (hotDB - floorDB) + lev = math.Max(0, math.Min(1, lev)) + // Power-law compression + lev = math.Max(math.Pow(lev, 0.75)-0.02, 0) + bands[b] = lev + } + + // Peak normalization (cap scale at 3x to avoid blowing up silence) + peak := 0.0 + for _, v := range bands { + if v > peak { + peak = v } - rms := math.Sqrt(sum / float64(end-start)) - bands[b] = math.Min(rms*8.0, 1.0) } - return bands -} + normScale := math.Min(0.95/(peak+1e-6), 3.0) + for b := range bands { + bands[b] = math.Min(bands[b]*normScale, 1.0) + } + + // Exponential decay smoothing (~100ms time constant) + decay := math.Exp(-float64(n) / sr / 0.1) -func ComputeLevelDB(samples []int16) float64 { + // RMS level in dB var sum float64 for _, s := range samples { v := float64(s) / 32768.0 sum += v * v } - rms := math.Sqrt(sum / float64(len(samples))) - return 20 * math.Log10(rms+1e-10) + rms := math.Sqrt(sum / float64(n)) + db := 20 * math.Log10(rms+1e-10) + + p.mu.Lock() + for b := 0; b < nb; b++ { + if bands[b] > p.fftBands[b]*decay { + p.fftBands[b] = bands[b] + } else { + p.fftBands[b] *= decay + } + } + p.level = db + p.mu.Unlock() } -func Int16LEToBytes(samples []int16) []byte { +func SamplesToBytes(samples []int16) []byte { buf := make([]byte, len(samples)*2) for i, s := range samples { binary.LittleEndian.PutUint16(buf[i*2:], uint16(s)) } return buf } + +func BytesToSamples(data []byte) []int16 { + n := len(data) / 2 + samples := make([]int16, n) + for i := range samples { + samples[i] = int16(binary.LittleEndian.Uint16(data[i*2:])) + } + return samples +} diff --git 
a/pkg/console/ringbuffer.go b/pkg/console/ringbuffer.go index 73db5bb7..29648bb4 100644 --- a/pkg/console/ringbuffer.go +++ b/pkg/console/ringbuffer.go @@ -8,7 +8,7 @@ import ( ) // RingBuffer is a SPSC ring buffer for int16 audio samples. -// Overwrites oldest data when full (lossy). +// When the writer outruns the reader, the reader skips ahead to avoid stale data. type RingBuffer struct { buf []int16 size int @@ -29,16 +29,44 @@ func NewRingBuffer(size int) *RingBuffer { func (rb *RingBuffer) Write(samples []int16) int { n := len(samples) + if n > rb.size { + samples = samples[n-rb.size:] + n = rb.size + } w := int(rb.w.Load()) for i := 0; i < n; i++ { rb.buf[(w+i)%rb.size] = samples[i] } rb.w.Add(int64(n)) - rb.cond.Signal() return n } +// ReadAvailable copies up to len(out) available samples into out (non-blocking). +// Returns the number of samples actually copied. +func (rb *RingBuffer) ReadAvailable(out []int16) int { + avail := int(rb.w.Load() - rb.r.Load()) + if avail <= 0 { + return 0 + } + // If writer has lapped us, skip ahead + if avail > rb.size { + skip := int64(avail - rb.size) + rb.r.Add(skip) + avail = rb.size + } + n := len(out) + if n > avail { + n = avail + } + r := int(rb.r.Load()) + for i := 0; i < n; i++ { + out[i] = rb.buf[(r+i)%rb.size] + } + rb.r.Add(int64(n)) + return n +} + // Read blocks until len(out) samples are available, then copies them. func (rb *RingBuffer) Read(out []int16) bool { needed := len(out) @@ -53,6 +81,11 @@ func (rb *RingBuffer) Read(out []int16) bool { rb.mu.Unlock() continue } + if avail > rb.size { + skip := int64(avail - rb.size) + rb.r.Add(skip) + avail = rb.size + } toCopy := needed - copied if toCopy > avail { toCopy = avail @@ -71,7 +104,27 @@ func (rb *RingBuffer) Available() int { return int(rb.w.Load() - rb.r.Load()) } +// WaitForData blocks until samples are available in the buffer. +// Returns true if data is available, false if woken up with no data +// (e.g., after Reset or Broadcast for shutdown). 
+func (rb *RingBuffer) WaitForData() bool { + if rb.w.Load()-rb.r.Load() > 0 { + return true + } + rb.mu.Lock() + for rb.w.Load()-rb.r.Load() <= 0 { + rb.cond.Wait() + // After wakeup, re-check. If still empty (Reset/shutdown), return false. + if rb.w.Load()-rb.r.Load() <= 0 { + rb.mu.Unlock() + return false + } + } + rb.mu.Unlock() + return true +} + func (rb *RingBuffer) Reset() { - rb.r.Store(0) - rb.w.Store(0) + rb.r.Store(rb.w.Load()) + rb.cond.Broadcast() } diff --git a/pkg/console/tcp.go b/pkg/console/tcp.go index b57e1741..4df8320e 100644 --- a/pkg/console/tcp.go +++ b/pkg/console/tcp.go @@ -9,15 +9,10 @@ import ( "io" "net" "sync" - "time" -) -// Message types for the TCP protocol. -const ( - MsgCapture byte = 0x01 // capture audio (CLI → Agent) - MsgRender byte = 0x02 // render audio (Agent → CLI) - MsgConfig byte = 0x03 // config (bidirectional) - MsgEOF byte = 0x04 // graceful shutdown + "google.golang.org/protobuf/proto" + + agent "github.com/livekit/protocol/livekit/agent" ) type TCPServer struct { @@ -64,6 +59,13 @@ func (s *TCPServer) Accept() (net.Conn, error) { return conn, nil } +// Conn returns the accepted connection, or nil if none. +func (s *TCPServer) Conn() net.Conn { + s.mu.Lock() + defer s.mu.Unlock() + return s.conn +} + func (s *TCPServer) Close() error { s.mu.Lock() defer s.mu.Unlock() @@ -78,69 +80,47 @@ func (s *TCPServer) Close() error { return errors.Join(errs...) } -// WriteMessage sends a framed message: [1 byte type][4 bytes BE length][payload]. -func WriteMessage(w io.Writer, msgType byte, payload []byte) error { - header := [5]byte{msgType} - binary.BigEndian.PutUint32(header[1:], uint32(len(payload))) - if _, err := w.Write(header[:]); err != nil { - return err - } - if len(payload) > 0 { - _, err := w.Write(payload) - return err +// WriteSessionMessage sends a protobuf-framed message: [4 bytes BE length][proto bytes]. +// Uses a single write call to avoid split TCP segments. 
+func WriteSessionMessage(w io.Writer, msg *agent.SessionMessage) error { + data, err := proto.Marshal(msg) + if err != nil { + return fmt.Errorf("console tcp: marshal: %w", err) } - return nil -} -// ReadMessage reads a framed message. Returns (type, payload, error). -// Sets a read deadline to detect stale connections. -func ReadMessage(r io.Reader) (byte, []byte, error) { - if conn, ok := r.(net.Conn); ok { - conn.SetReadDeadline(time.Now().Add(5 * time.Second)) - } + // Combine header + payload into a single write + buf := make([]byte, 4+len(data)) + binary.BigEndian.PutUint32(buf[:4], uint32(len(data))) + copy(buf[4:], data) + _, err = w.Write(buf) + return err +} - var header [5]byte +// ReadSessionMessage reads a protobuf-framed message: [4 bytes BE length][proto bytes]. +// Blocks until a complete message is available. +func ReadSessionMessage(r io.Reader) (*agent.SessionMessage, error) { + var header [4]byte if _, err := io.ReadFull(r, header[:]); err != nil { - return 0, nil, err + return nil, err } - msgType := header[0] - length := binary.BigEndian.Uint32(header[1:]) + length := binary.BigEndian.Uint32(header[:]) if length > 1<<20 { // 1MB sanity limit - return 0, nil, fmt.Errorf("console tcp: message too large: %d bytes", length) - } - - if length == 0 { - return msgType, nil, nil - } - - // Reset deadline for payload read - if conn, ok := r.(net.Conn); ok { - conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + return nil, fmt.Errorf("console tcp: message too large: %d bytes", length) } - payload := make([]byte, length) - if _, err := io.ReadFull(r, payload); err != nil { - return 0, nil, fmt.Errorf("console tcp: partial message: %w", err) + data := make([]byte, length) + if length > 0 { + if _, err := io.ReadFull(r, data); err != nil { + return nil, fmt.Errorf("console tcp: partial message: %w", err) + } } - return msgType, payload, nil -} - -func SamplesToBytes(samples []int16) []byte { - buf := make([]byte, len(samples)*2) - for i, s := range 
samples { - binary.LittleEndian.PutUint16(buf[i*2:], uint16(s)) + msg := &agent.SessionMessage{} + if err := proto.Unmarshal(data, msg); err != nil { + return nil, fmt.Errorf("console tcp: unmarshal: %w", err) } - return buf + return msg, nil } -func BytesToSamples(data []byte) []int16 { - n := len(data) / 2 - samples := make([]int16, n) - for i := range samples { - samples[i] = int16(binary.LittleEndian.Uint16(data[i*2:])) - } - return samples -}