package queue_test

import (
	"context"
	"testing"

	"github.com/go-json-experiment/json"
	"github.com/twmb/franz-go/pkg/kgo"

	"git.sunturtle.xyz/zephyr/kaiyan/queue"
)

// spyProducer is a test double for the producer that queue.Send writes to.
// It keeps a copy of every record it is handed and reports success to the
// caller's promise before Produce returns.
type spyProducer struct {
	got []kgo.Record // copies of each produced record, in call order
}

// Produce stores a copy of r, then invokes done with r and a nil error.
// The context is not consulted.
func (s *spyProducer) Produce(_ context.Context, r *kgo.Record, done func(*kgo.Record, error)) {
	s.got = append(s.got, *r)
	done(r, nil)
}

// TestSend checks that queue.Send produces exactly one record for a message,
// that the record is keyed by the message ID, and that the record value
// decodes back into an identical message.
func TestSend(t *testing.T) {
	msg := queue.Message{
		ID:        "bocchi",
		Channel:   "kessoku",
		Sender:    queue.Sender(make([]byte, 16), "kessoku", "ryō"),
		Timestamp: 1,
		Text:      "bocchi the rock!",
	}
	var q spyProducer
	// Buffered so an error report cannot block. spyProducer invokes the
	// promise before Produce returns, so a non-blocking check after Send is
	// assumed sufficient to observe any reported error.
	errs := make(chan error, 1)
	queue.Send(context.Background(), &q, msg, errs)
	select {
	case err := <-errs:
		t.Error(err)
	default: // no error reported
	}
	if len(q.got) != 1 {
		t.Fatalf("wrong number of records produced: %d", len(q.got))
	}
	rec := &q.got[0]
	if string(rec.Key) != msg.ID {
		t.Errorf("record has wrong key: want %q, got %q", msg.ID, rec.Key)
	}
	var got queue.Message
	if err := json.Unmarshal(rec.Value, &got); err != nil {
		// Stop here: comparing a partially decoded message would only add a
		// misleading round-trip failure on top of the real decode error.
		t.Fatalf("decoding record value %q: %v", rec.Value, err)
	}
	if got != msg {
		t.Errorf("message did not round-trip:\nwant %+v\ngot  %+v", msg, got)
	}
}

// spyConsumer is a test double for the consumer that queue.Recv polls.
// Every poll returns the same canned fetches.
type spyConsumer struct {
	get kgo.Fetches // returned unchanged by each PollFetches call
}

// PollFetches ignores its context and returns the canned fetches.
func (s *spyConsumer) PollFetches(context.Context) kgo.Fetches {
	return s.get
}

// TestRecv feeds queue.Recv a single fetched record that holds a
// JSON-encoded message, and checks that exactly that message comes back.
func TestRecv(t *testing.T) {
	want := queue.Message{
		ID:      "bocchi",
		Channel: "kessoku",
		Sender:  queue.Sender(make([]byte, 16), "kessoku", "ryō"),
		Text:    "bocchi the rock!",
	}
	payload, err := json.Marshal(&want)
	if err != nil {
		t.Fatal(err)
	}

	// Build the canned fetch from the inside out: record → partition → topic.
	record := &kgo.Record{Key: []byte("bocchi"), Value: payload, Partition: 1}
	partition := kgo.FetchPartition{Partition: 1, Records: []*kgo.Record{record}}
	topic := kgo.FetchTopic{Topic: queue.Topic, Partitions: []kgo.FetchPartition{partition}}
	consumer := spyConsumer{get: kgo.Fetches{{Topics: []kgo.FetchTopic{topic}}}}

	msgs, err := queue.Recv(context.Background(), &consumer, nil)
	if err != nil {
		t.Error(err)
	}
	if len(msgs) != 1 {
		t.Fatalf("wrong number of messages: want 1, got %d", len(msgs))
	}
	if msgs[0] != want {
		t.Errorf("message did not round-trip:\nwant %+v\ngot  %+v", want, msgs[0])
	}
}