diff --git a/queue/message.go b/queue/message.go index fea16f5..298cd00 100644 --- a/queue/message.go +++ b/queue/message.go @@ -2,6 +2,8 @@ package queue import ( "context" + "errors" + "fmt" "sync" "github.com/go-json-experiment/json" @@ -25,7 +27,7 @@ type Producer interface { Produce(ctx context.Context, rec *kgo.Record, promise func(*kgo.Record, error)) } -const topic = "kaiyan.chat" +const Topic = "kaiyan.chat" var recordPool sync.Pool @@ -37,7 +39,7 @@ func Send(ctx context.Context, cl Producer, msg Message, errs chan<- error) { } rec, _ := recordPool.Get().(*kgo.Record) if rec == nil { - rec = &kgo.Record{Topic: topic} + rec = &kgo.Record{Topic: Topic} } rec.Key = append(rec.Key[:0], msg.ID...) rec.Value = append(rec.Value[:0], b...) @@ -51,3 +53,23 @@ func Send(ctx context.Context, cl Producer, msg Message, errs chan<- error) { } }) } + +type Consumer interface { + PollFetches(ctx context.Context) kgo.Fetches +} + +func Recv(ctx context.Context, cl Consumer, onto []Message) ([]Message, error) { + f := cl.PollFetches(ctx) + var errs error + f.EachError(func(s string, i int32, err error) { + errs = errors.Join(errs, fmt.Errorf("partition %d: %w", i, err)) + }) + f.EachRecord(func(r *kgo.Record) { + var msg Message + if err := json.Unmarshal(r.Value, &msg); err != nil { + errs = errors.Join(errs, fmt.Errorf("partition %d: %w", r.Partition, err)) + } + onto = append(onto, msg) + }) + return onto, errs +} diff --git a/queue/message_test.go b/queue/message_test.go index 0fe233f..4cc9ae0 100644 --- a/queue/message_test.go +++ b/queue/message_test.go @@ -4,9 +4,10 @@ import ( "context" "testing" - "git.sunturtle.xyz/zephyr/kaiyan/queue" "github.com/go-json-experiment/json" "github.com/twmb/franz-go/pkg/kgo" + + "git.sunturtle.xyz/zephyr/kaiyan/queue" ) type spyProducer struct { @@ -48,3 +49,57 @@ func TestSend(t *testing.T) { t.Errorf("message did not round-trip:\nwant %+v\ngot %+v", msg, got) } } + +type spyConsumer struct { + get kgo.Fetches +} + +func (c *spyConsumer) PollFetches(ctx context.Context) kgo.Fetches { + return c.get +} + +func TestRecv(t *testing.T) { + msg := queue.Message{ + ID: "bocchi", + Channel: "kessoku", + Sender: queue.Sender(make([]byte, 16), "kessoku", "ryō"), + Text: "bocchi the rock!", + } + val, err := json.Marshal(&msg) + if err != nil { + t.Fatal(err) + } + p := spyConsumer{ + get: kgo.Fetches{ + { + Topics: []kgo.FetchTopic{ + { + Topic: queue.Topic, + Partitions: []kgo.FetchPartition{ + { + Partition: 1, + Records: []*kgo.Record{ + { + Key: []byte("bocchi"), + Value: val, + Partition: 1, + }, + }, + }, + }, + }, + }, + }, + }, + } + got, err := queue.Recv(context.Background(), &p, nil) + if err != nil { + t.Error(err) + } + if len(got) != 1 { + t.Fatalf("wrong number of messages: want 1, got %d", len(got)) + } + if got[0] != msg { + t.Errorf("message did not round-trip:\nwant %+v\ngot %+v", msg, got[0]) + } +}