diff --git a/pkg/submit/direct.go b/pkg/submit/direct.go index 9040759..b3b2daf 100644 --- a/pkg/submit/direct.go +++ b/pkg/submit/direct.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "strconv" "strings" "sync" "time" @@ -19,7 +20,10 @@ const ( maxSequenceRetryRounds = 2 ) -var errSequenceMismatch = errors.New("account sequence mismatch") +var ( + errSequenceMismatch = errors.New("account sequence mismatch") + errTooManyInFlight = errors.New("too many in-flight submissions") +) // DirectConfig contains the fixed submission settings Apex owns for direct // celestia-app writes. @@ -42,6 +46,12 @@ type DirectSubmitter struct { pollInterval time.Duration feeDenom string mu sync.Mutex + inFlight int + accountNumber uint64 + nextSequence uint64 + sequenceReady bool + pendingSequences map[string]uint64 + maxInFlight int } // NewDirectSubmitter builds a concrete single-account submitter. @@ -77,6 +87,7 @@ func NewDirectSubmitter(app AppClient, signer *Signer, cfg DirectConfig) (*Direc confirmationTimeout: cfg.ConfirmationTimeout, pollInterval: defaultPollInterval, feeDenom: defaultFeeDenom, + pendingSequences: make(map[string]uint64), }, nil } @@ -87,29 +98,23 @@ func (s *DirectSubmitter) Close() error { return s.app.Close() } -// Submit serializes submissions for the configured signer so sequence handling -// stays bounded and explicit in v1. +// Submit serializes sequence reservation and broadcast for the configured +// signer, then waits for confirmation without blocking the next nonce. func (s *DirectSubmitter) Submit(ctx context.Context, req *Request) (*Result, error) { if err := validateSubmitRequest(req); err != nil { return nil, err } + if err := s.startSubmission(); err != nil { + return nil, err + } + defer s.finishSubmission() - s.mu.Lock() - defer s.mu.Unlock() - - var lastErr error - for range maxSequenceRetryRounds { - result, err := s.submitOnce(ctx, req) - if err == nil { - return result, nil - } - lastErr = err - if !errors.Is(err, errSequenceMismatch) { - return nil, err - } + broadcast, err := s.broadcastTx(ctx, req) + if err != nil { + return nil, err } - return nil, lastErr + return s.waitForConfirmation(ctx, broadcast.Hash) } func validateSubmitRequest(req *Request) error { @@ -127,32 +132,154 @@ func validateSubmitRequest(req *Request) error { return nil } -func (s *DirectSubmitter) submitOnce(ctx context.Context, req *Request) (*Result, error) { - account, err := s.app.AccountInfo(ctx, s.signer.Address()) - if err != nil { - return nil, fmt.Errorf("query submission account: %w", err) +func (s *DirectSubmitter) broadcastTx(ctx context.Context, req *Request) (*TxStatus, error) { + s.mu.Lock() + defer s.mu.Unlock() + + var lastErr error + for range maxSequenceRetryRounds { + account, err := s.nextAccountLocked(ctx) + if err != nil { + return nil, err + } + + txBytes, err := s.buildBlobTx(req, account) + if err != nil { + return nil, err + } + + broadcast, err := s.app.BroadcastTx(ctx, txBytes) + if err != nil { + if isSequenceMismatchText(err.Error()) { + s.recoverSequenceLocked(account, err.Error()) + lastErr = fmt.Errorf("%w: %w", errSequenceMismatch, err) + continue + } + return nil, fmt.Errorf("broadcast blob tx: %w", err) + } + if err := checkTxStatus("broadcast", broadcast); err != nil { + if errors.Is(err, errSequenceMismatch) { + s.recoverSequenceLocked(account, err.Error()) + lastErr = err + continue + } + return nil, err + } + + if broadcast.Hash != "" { + s.rememberPendingLocked(broadcast.Hash, account.Sequence) + } + s.nextSequence = account.Sequence + 1 + s.sequenceReady = true + return broadcast, nil } - if account == nil { - return nil, errors.New("query submission account: empty response") + + return nil, lastErr +} + +func (s *DirectSubmitter) nextAccountLocked(ctx context.Context) (*AccountInfo, error) { + if !s.sequenceReady { + account, err := s.app.AccountInfo(ctx, s.signer.Address()) + if err != nil { + return nil, fmt.Errorf("query submission account: %w", err) + } + if account == nil { + return nil, errors.New("query submission account: empty response") + } + + s.accountNumber = account.AccountNumber + s.nextSequence = account.Sequence + s.sequenceReady = true + if err := s.reconcilePendingLocked(ctx); err != nil { + return nil, err + } } - txBytes, err := s.buildBlobTx(req, account) - if err != nil { - return nil, err + return &AccountInfo{ + Address: s.signer.Address(), + AccountNumber: s.accountNumber, + Sequence: s.nextSequence, + }, nil +} + +func (s *DirectSubmitter) invalidateSequenceLocked() { + s.accountNumber = 0 + s.nextSequence = 0 + s.sequenceReady = false +} + +func (s *DirectSubmitter) startSubmission() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.maxInFlight > 0 && s.inFlight >= s.maxInFlight { + return errTooManyInFlight } + s.inFlight++ + return nil +} - broadcast, err := s.app.BroadcastTx(ctx, txBytes) - if err != nil { - if isSequenceMismatchText(err.Error()) { - return nil, fmt.Errorf("%w: %w", errSequenceMismatch, err) +func (s *DirectSubmitter) finishSubmission() { + s.mu.Lock() + defer s.mu.Unlock() + + if s.inFlight > 0 { + s.inFlight-- + } +} + +func (s *DirectSubmitter) recoverSequenceLocked(account *AccountInfo, errText string) { + expected, ok := expectedSequenceFromMismatchText(errText) + if !ok { + s.invalidateSequenceLocked() + return + } + + s.accountNumber = account.AccountNumber + s.nextSequence = expected + s.sequenceReady = true +} + +func (s *DirectSubmitter) reconcilePendingLocked(ctx context.Context) error { + if len(s.pendingSequences) == 0 { + return nil + } + + nextSequence := s.nextSequence + for hash, sequence := range s.pendingSequences { + _, err := s.app.GetTx(ctx, hash) + if err == nil { + delete(s.pendingSequences, hash) + continue } - return nil, fmt.Errorf("broadcast blob tx: %w", err) + if isTxNotFound(err) { + if sequence >= nextSequence { + nextSequence = sequence + 1 + } + continue + } + return fmt.Errorf("reconcile pending blob tx %s: %w", hash, err) } - if err := checkTxStatus("broadcast", broadcast); err != nil { - return nil, err + + s.nextSequence = nextSequence + return nil +} + +func (s *DirectSubmitter) rememberPendingLocked(hash string, sequence uint64) { + if hash == "" { + return } + s.pendingSequences[hash] = sequence +} - return s.waitForConfirmation(ctx, broadcast.Hash) +func (s *DirectSubmitter) clearPending(hash string) { + if hash == "" { + return + } + + s.mu.Lock() + defer s.mu.Unlock() + delete(s.pendingSequences, hash) } func (s *DirectSubmitter) buildBlobTx(req *Request, account *AccountInfo) ([]byte, error) { @@ -297,6 +424,7 @@ func (s *DirectSubmitter) waitForConfirmation(parent context.Context, hash strin for { tx, err := s.app.GetTx(ctx, hash) if err == nil { + s.clearPending(hash) if err := checkTxStatus("confirm", tx); err != nil { return nil, err } @@ -339,6 +467,29 @@ func isSequenceMismatchText(text string) bool { return strings.Contains(text, "account sequence mismatch") || strings.Contains(text, "incorrect account sequence") } +func expectedSequenceFromMismatchText(text string) (uint64, bool) { + lower := strings.ToLower(text) + idx := strings.Index(lower, "expected ") + if idx < 0 { + return 0, false + } + + start := idx + len("expected ") + end := start + for end < len(lower) && lower[end] >= '0' && lower[end] <= '9' { + end++ + } + if end == start { + return 0, false + } + + sequence, err := strconv.ParseUint(lower[start:end], 10, 64) + if err != nil { + return 0, false + } + return sequence, true +} + func isTxNotFound(err error) bool { return status.Code(err) == codes.NotFound } diff --git a/pkg/submit/direct_test.go b/pkg/submit/direct_test.go index 5944143..3cf658b 100644 --- a/pkg/submit/direct_test.go +++ b/pkg/submit/direct_test.go @@ -3,13 +3,19 @@ package submit import ( "context" "errors" + "fmt" + "slices" "strings" + "sync" "testing" "time" + txv1beta1 "github.com/evstack/apex/pkg/api/grpc/gen/cosmos/tx/v1beta1" "github.com/evstack/apex/pkg/types" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/proto" ) type fakeAppClient struct { @@ -145,8 +151,8 @@ func TestDirectSubmitterRetriesSequenceMismatch(t *testing.T) { if result.Height != 77 { t.Fatalf("height = %d, want 77", result.Height) } - if client.accountCalls != 2 { - t.Fatalf("account calls = %d, want 2", client.accountCalls) + if client.accountCalls != 1 { + t.Fatalf("account calls = %d, want 1", client.accountCalls) } if client.broadcastCalls != 2 { t.Fatalf("broadcast calls = %d, want 2", client.broadcastCalls) @@ -269,6 +275,230 @@ func TestDirectSubmitterConfirmationTimeout(t *testing.T) { } } +func TestDirectSubmitterSpamReservesSequencesBeforeConfirmation(t *testing.T) { + t.Parallel() + + const submissions = 6 + + signer := mustSigner(t) + client := newSequenceSpamAppClient(signer.Address(), 7, 11, submissions) + + submitter, err := NewDirectSubmitter(client, signer, DirectConfig{ + ChainID: "mocha-4", + GasPrice: 0.002, + ConfirmationTimeout: 50 * time.Millisecond, + }) + if err != nil { + t.Fatalf("NewDirectSubmitter: %v", err) + } + submitter.pollInterval = time.Millisecond + + var ( + wg sync.WaitGroup + start = make(chan struct{}) + errs = make(chan error, submissions) + heights = make(chan uint64, submissions) + ) + + for range submissions { + wg.Add(1) + go func() { + defer wg.Done() + <-start + + result, submitErr := submitter.Submit(context.Background(), testRequest()) + if submitErr != nil { + errs <- submitErr + return + } + heights <- result.Height + }() + } + + close(start) + wg.Wait() + close(errs) + close(heights) + + if len(errs) > 0 { + t.Fatalf("unexpected submit error: %v", <-errs) + } + if client.accountCalls != 1 { + t.Fatalf("account calls = %d, want 1", client.accountCalls) + } + if client.broadcastCalls != submissions { + t.Fatalf("broadcast calls = %d, want %d", client.broadcastCalls, submissions) + } + + wantSequences := []uint64{11, 12, 13, 14, 15, 16} + if !slices.Equal(client.sequences, wantSequences) { + t.Fatalf("broadcast sequences = %v, want %v", client.sequences, wantSequences) + } + + if got := len(heights); got != submissions { + t.Fatalf("confirmed results = %d, want %d", got, submissions) + } +} + +func TestDirectSubmitterRecoversAfterRestartWithPendingSequences(t *testing.T) { + t.Parallel() + + signer := mustSigner(t) + client := newSequenceRecoveryAppClient(signer.Address(), 7, 11, 16) + + // Simulate a fresh process with no local sequence cache while the mempool + // still holds earlier pending transactions. + submitter, err := NewDirectSubmitter(client, signer, DirectConfig{ + ChainID: "mocha-4", + GasPrice: 0.002, + ConfirmationTimeout: 100 * time.Millisecond, + }) + if err != nil { + t.Fatalf("NewDirectSubmitter: %v", err) + } + submitter.pollInterval = time.Millisecond + + result, err := submitter.Submit(context.Background(), testRequest()) + if err != nil { + t.Fatalf("Submit: %v", err) + } + if result.Height != 116 { + t.Fatalf("height = %d, want 116", result.Height) + } + if client.accountCalls != 1 { + t.Fatalf("account calls = %d, want 1", client.accountCalls) + } + if !slices.Equal(client.attemptSequences, []uint64{11, 16}) { + t.Fatalf("attempt sequences = %v, want [11 16]", client.attemptSequences) + } + if client.lastSuccessfulSequence() != 16 { + t.Fatalf("successful sequence = %d, want 16", client.lastSuccessfulSequence()) + } +} + +func TestDirectSubmitterRecoversWhenCachedSequenceFallsBehindExternalWriter(t *testing.T) { + t.Parallel() + + signer := mustSigner(t) + client := newSequenceRecoveryAppClient(signer.Address(), 7, 16, 16) + client.afterSuccessNext = []uint64{19} + + submitter, err := NewDirectSubmitter(client, signer, DirectConfig{ + ChainID: "mocha-4", + GasPrice: 0.002, + ConfirmationTimeout: 100 * time.Millisecond, + }) + if err != nil { + t.Fatalf("NewDirectSubmitter: %v", err) + } + submitter.pollInterval = time.Millisecond + + first, err := submitter.Submit(context.Background(), testRequest()) + if err != nil { + t.Fatalf("first Submit: %v", err) + } + if first.Height != 116 { + t.Fatalf("first height = %d, want 116", first.Height) + } + + second, err := submitter.Submit(context.Background(), testRequest()) + if err != nil { + t.Fatalf("second Submit: %v", err) + } + if second.Height != 119 { + t.Fatalf("second height = %d, want 119", second.Height) + } + if client.accountCalls != 1 { + t.Fatalf("account calls = %d, want 1", client.accountCalls) + } + if !slices.Equal(client.attemptSequences, []uint64{16, 17, 19}) { + t.Fatalf("attempt sequences = %v, want [16 17 19]", client.attemptSequences) + } +} + +func TestDirectSubmitterReconcilesPersistedPendingSequenceBeforeBroadcast(t *testing.T) { + t.Parallel() + + signer := mustSigner(t) + client := &persistedPendingReconcileAppClient{ + address: signer.Address(), + accountNumber: 7, + committedSequence: 11, + pendingHash: "tx-11", + expectedBroadcastSequence: 12, + } + + submitter, err := NewDirectSubmitter(client, signer, DirectConfig{ + ChainID: "mocha-4", + GasPrice: 0.002, + ConfirmationTimeout: 100 * time.Millisecond, + }) + if err != nil { + t.Fatalf("NewDirectSubmitter: %v", err) + } + submitter.pollInterval = time.Millisecond + submitter.pendingSequences = map[string]uint64{"tx-11": 11} + + result, err := submitter.Submit(context.Background(), testRequest()) + if err != nil { + t.Fatalf("Submit: %v", err) + } + if result.Height != 112 { + t.Fatalf("height = %d, want 112", result.Height) + } + if !slices.Equal(client.getTxHashes, []string{"tx-11", "tx-12"}) { + t.Fatalf("GetTx hashes = %v, want [tx-11 tx-12]", client.getTxHashes) + } + if !slices.Equal(client.attemptSequences, []uint64{12}) { + t.Fatalf("attempt sequences = %v, want [12]", client.attemptSequences) + } +} + +func TestDirectSubmitterRejectsWhenMaxInFlightExceeded(t *testing.T) { + t.Parallel() + + signer := mustSigner(t) + client := newInFlightLimitAppClient(signer.Address(), 7, 1) + + submitter, err := NewDirectSubmitter(client, signer, DirectConfig{ + ChainID: "mocha-4", + GasPrice: 0.002, + ConfirmationTimeout: 100 * time.Millisecond, + }) + if err != nil { + t.Fatalf("NewDirectSubmitter: %v", err) + } + submitter.pollInterval = time.Millisecond + submitter.maxInFlight = 1 + + firstDone := make(chan error, 1) + go func() { + _, submitErr := submitter.Submit(context.Background(), testRequest()) + firstDone <- submitErr + }() + + client.waitForBroadcast(t) + + secondCtx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + + _, err = submitter.Submit(secondCtx, testRequest()) + if err == nil { + t.Fatal("expected max in-flight error, got nil") + } + if !strings.Contains(err.Error(), "too many in-flight") { + t.Fatalf("error = %v, want max in-flight rejection", err) + } + if client.broadcastCount() != 1 { + t.Fatalf("broadcast calls = %d, want 1", client.broadcastCount()) + } + + close(client.release) + if err := <-firstDone; err != nil { + t.Fatalf("first Submit: %v", err) + } +} + func mustSigner(t *testing.T) *Signer { t.Helper() @@ -319,3 +549,380 @@ func queueTxStatus(statuses []*TxStatus, idx int) *TxStatus { } return statuses[idx] } + +type sequenceSpamAppClient struct { + address string + accountNumber uint64 + baseSequence uint64 + total int + + mu sync.Mutex + sequences []uint64 + hashes map[string]uint64 + accountCalls int + broadcastCalls int +} + +func newSequenceSpamAppClient(address string, accountNumber, baseSequence uint64, total int) *sequenceSpamAppClient { + return &sequenceSpamAppClient{ + address: address, + accountNumber: accountNumber, + baseSequence: baseSequence, + total: total, + hashes: make(map[string]uint64, total), + } +} + +func (c *sequenceSpamAppClient) AccountInfo(_ context.Context, _ string) (*AccountInfo, error) { + c.mu.Lock() + defer c.mu.Unlock() + + c.accountCalls++ + return &AccountInfo{ + Address: c.address, + AccountNumber: c.accountNumber, + Sequence: c.baseSequence, + }, nil +} + +func (c *sequenceSpamAppClient) BroadcastTx(_ context.Context, txBytes []byte) (*TxStatus, error) { + sequence, err := decodeSequenceFromBlobTx(txBytes) + if err != nil { + return nil, err + } + + c.mu.Lock() + defer c.mu.Unlock() + + c.broadcastCalls++ + expected := c.baseSequence + uint64(len(c.sequences)) + if sequence != expected { + return &TxStatus{ + Code: 32, + RawLog: fmt.Sprintf("account sequence mismatch, expected %d, got %d", expected, sequence), + }, nil + } + + hash := fmt.Sprintf("tx-%d", sequence) + c.sequences = append(c.sequences, sequence) + c.hashes[hash] = sequence + return &TxStatus{Hash: hash}, nil +} + +func (c *sequenceSpamAppClient) GetTx(_ context.Context, hash string) (*TxStatus, error) { + c.mu.Lock() + defer c.mu.Unlock() + + sequence, ok := c.hashes[hash] + if !ok || len(c.sequences) < c.total { + return nil, status.Error(codes.NotFound, "not found") + } + + return &TxStatus{ + Hash: hash, + Height: int64(100 + sequence), + }, nil +} + +func (*sequenceSpamAppClient) Close() error { + return nil +} + +func decodeSequenceFromBlobTx(raw []byte) (uint64, error) { + innerTx, err := decodeInnerTx(raw) + if err != nil { + return 0, err + } + + var txRaw txv1beta1.TxRaw + if err := proto.Unmarshal(innerTx, &txRaw); err != nil { + return 0, fmt.Errorf("unmarshal tx raw: %w", err) + } + + var authInfo txv1beta1.AuthInfo + if err := proto.Unmarshal(txRaw.GetAuthInfoBytes(), &authInfo); err != nil { + return 0, fmt.Errorf("unmarshal auth info: %w", err) + } + + signerInfos := authInfo.GetSignerInfos() + if len(signerInfos) != 1 { + return 0, fmt.Errorf("signer infos = %d, want 1", len(signerInfos)) + } + + return signerInfos[0].GetSequence(), nil +} + +func decodeInnerTx(raw []byte) ([]byte, error) { + data := raw + for len(data) > 0 { + num, typ, n := protowire.ConsumeTag(data) + if n < 0 { + return nil, errors.New("decode blob tx tag") + } + data = data[n:] + if typ != protowire.BytesType { + return nil, fmt.Errorf("unexpected wire type %d for field %d", typ, num) + } + + value, n := protowire.ConsumeBytes(data) + if n < 0 { + return nil, fmt.Errorf("decode blob tx field %d", num) + } + data = data[n:] + + if num == 1 { + return value, nil + } + } + + return nil, errors.New("blob tx inner tx missing") +} + +type sequenceRecoveryAppClient struct { + address string + accountNumber uint64 + committedSequence uint64 + nextAvailable uint64 + afterSuccessNext []uint64 + + mu sync.Mutex + accountCalls int + attemptSequences []uint64 + successHashes map[string]uint64 + successCount int +} + +func newSequenceRecoveryAppClient(address string, accountNumber, committedSequence, nextAvailable uint64) *sequenceRecoveryAppClient { + return &sequenceRecoveryAppClient{ + address: address, + accountNumber: accountNumber, + committedSequence: committedSequence, + nextAvailable: nextAvailable, + successHashes: make(map[string]uint64), + } +} + +func (c *sequenceRecoveryAppClient) AccountInfo(_ context.Context, _ string) (*AccountInfo, error) { + c.mu.Lock() + defer c.mu.Unlock() + + c.accountCalls++ + return &AccountInfo{ + Address: c.address, + AccountNumber: c.accountNumber, + Sequence: c.committedSequence, + }, nil +} + +func (c *sequenceRecoveryAppClient) BroadcastTx(_ context.Context, txBytes []byte) (*TxStatus, error) { + sequence, err := decodeSequenceFromBlobTx(txBytes) + if err != nil { + return nil, err + } + + c.mu.Lock() + defer c.mu.Unlock() + + c.attemptSequences = append(c.attemptSequences, sequence) + if sequence != c.nextAvailable { + return &TxStatus{ + Code: 32, + RawLog: fmt.Sprintf("account sequence mismatch, expected %d, got %d", c.nextAvailable, sequence), + }, nil + } + + hash := fmt.Sprintf("tx-%d", sequence) + c.successHashes[hash] = sequence + c.successCount++ + if c.successCount <= len(c.afterSuccessNext) { + c.nextAvailable = c.afterSuccessNext[c.successCount-1] + } else { + c.nextAvailable = sequence + 1 + } + + return &TxStatus{Hash: hash}, nil +} + +func (c *sequenceRecoveryAppClient) GetTx(_ context.Context, hash string) (*TxStatus, error) { + c.mu.Lock() + defer c.mu.Unlock() + + sequence, ok := c.successHashes[hash] + if !ok { + return nil, status.Error(codes.NotFound, "not found") + } + + return &TxStatus{ + Hash: hash, + Height: int64(100 + sequence), + }, nil +} + +func (c *sequenceRecoveryAppClient) lastSuccessfulSequence() uint64 { + c.mu.Lock() + defer c.mu.Unlock() + + var last uint64 + for _, sequence := range c.successHashes { + if sequence > last { + last = sequence + } + } + return last +} + +func (*sequenceRecoveryAppClient) Close() error { + return nil +} + +type persistedPendingReconcileAppClient struct { + address string + accountNumber uint64 + committedSequence uint64 + pendingHash string + expectedBroadcastSequence uint64 + + mu sync.Mutex + getTxHashes []string + attemptSequences []uint64 +} + +func (c *persistedPendingReconcileAppClient) AccountInfo(_ context.Context, _ string) (*AccountInfo, error) { + return &AccountInfo{ + Address: c.address, + AccountNumber: c.accountNumber, + Sequence: c.committedSequence, + }, nil +} + +func (c *persistedPendingReconcileAppClient) BroadcastTx(_ context.Context, txBytes []byte) (*TxStatus, error) { + sequence, err := decodeSequenceFromBlobTx(txBytes) + if err != nil { + return nil, err + } + + c.mu.Lock() + defer c.mu.Unlock() + + c.attemptSequences = append(c.attemptSequences, sequence) + if sequence != c.expectedBroadcastSequence { + return &TxStatus{ + Code: 32, + RawLog: fmt.Sprintf("account sequence mismatch, expected %d, got %d", c.expectedBroadcastSequence, sequence), + }, nil + } + + return &TxStatus{Hash: fmt.Sprintf("tx-%d", sequence)}, nil +} + +func (c *persistedPendingReconcileAppClient) GetTx(_ context.Context, hash string) (*TxStatus, error) { + c.mu.Lock() + defer c.mu.Unlock() + + c.getTxHashes = append(c.getTxHashes, hash) + switch hash { + case c.pendingHash: + return nil, status.Error(codes.NotFound, "not found") + case fmt.Sprintf("tx-%d", c.expectedBroadcastSequence): + return &TxStatus{ + Hash: hash, + Height: int64(100 + c.expectedBroadcastSequence), + }, nil + default: + return nil, status.Error(codes.NotFound, "not found") + } +} + +func (*persistedPendingReconcileAppClient) Close() error { + return nil +} + +type inFlightLimitAppClient struct { + address string + accountNumber uint64 + baseSequence uint64 + + release chan struct{} + broadcasted chan struct{} + + mu sync.Mutex + broadcastCalls int + hashes map[string]uint64 +} + +func newInFlightLimitAppClient(address string, accountNumber, baseSequence uint64) *inFlightLimitAppClient { + return &inFlightLimitAppClient{ + address: address, + accountNumber: accountNumber, + baseSequence: baseSequence, + release: make(chan struct{}), + broadcasted: make(chan struct{}, 1), + hashes: make(map[string]uint64), + } +} + +func (c *inFlightLimitAppClient) AccountInfo(_ context.Context, _ string) (*AccountInfo, error) { + return &AccountInfo{ + Address: c.address, + AccountNumber: c.accountNumber, + Sequence: c.baseSequence, + }, nil +} + +func (c *inFlightLimitAppClient) BroadcastTx(_ context.Context, txBytes []byte) (*TxStatus, error) { + sequence, err := decodeSequenceFromBlobTx(txBytes) + if err != nil { + return nil, err + } + + c.mu.Lock() + defer c.mu.Unlock() + + hash := fmt.Sprintf("tx-%d", sequence) + c.broadcastCalls++ + c.hashes[hash] = sequence + select { + case c.broadcasted <- struct{}{}: + default: + } + return &TxStatus{Hash: hash}, nil +} + +func (c *inFlightLimitAppClient) GetTx(_ context.Context, hash string) (*TxStatus, error) { + c.mu.Lock() + sequence, ok := c.hashes[hash] + c.mu.Unlock() + if !ok { + return nil, status.Error(codes.NotFound, "not found") + } + + select { + case <-c.release: + return &TxStatus{ + Hash: hash, + Height: int64(100 + sequence), + }, nil + default: + return nil, status.Error(codes.NotFound, "not found") + } +} + +func (*inFlightLimitAppClient) Close() error { + return nil +} + +func (c *inFlightLimitAppClient) waitForBroadcast(t *testing.T) { + t.Helper() + + select { + case <-c.broadcasted: + case <-time.After(time.Second): + t.Fatal("timed out waiting for first broadcast") + } +} + +func (c *inFlightLimitAppClient) broadcastCount() int { + c.mu.Lock() + defer c.mu.Unlock() + return c.broadcastCalls +}