WIP:PartitionState -- COPY
author Kai Moritz <kai@juplo.de>
Sat, 2 Nov 2024 16:15:25 +0000 (17:15 +0100)
committer Kai Moritz <kai@juplo.de>
Sat, 2 Nov 2024 16:15:25 +0000 (17:15 +0100)
src/main/java/de/juplo/kafka/PartitionState.java [new file with mode: 0644]

diff --git a/src/main/java/de/juplo/kafka/PartitionState.java b/src/main/java/de/juplo/kafka/PartitionState.java
new file mode 100644 (file)
index 0000000..0e66528
--- /dev/null
@@ -0,0 +1,446 @@
+package de.juplo.kafka;
+
+import lombok.extern.slf4j.Slf4j;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.producer.Producer;
+import org.apache.kafka.clients.producer.ProducerRecord;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.WakeupException;
+
+import java.time.Duration;
+import java.util.*;
+import java.util.concurrent.Phaser;
+
+
+@Slf4j
+public class ExampleConsumer implements Runnable, ConsumerRebalanceListener
+{
+  // Client-id of this instance; used as prefix in every log message
+  private final String id;
+  // Topic whose records are counted by key (see handleMessage)
+  private final String topic;
+  // Consumer that is subscribed to both the message-topic and the state-topic
+  private final Consumer<String, String> consumer;
+  // Thread that executes run(); created and started in the constructor
+  private final Thread workerThread;
+
+  // Topic the current counter values are sent to and restored from
+  private final String stateTopic;
+  // Producer used to send the counter values to the state-topic
+  Producer<String, String> producer;
+
+  // Loop condition of run(); set to false by shutdown()
+  private volatile boolean running = false;
+  // One party for the worker-thread itself; stateAssigned() registers one
+  // additional party per assigned partition, stateUnassigned() deregisters it
+  private final Phaser phaser = new Phaser(1);
+  // Partitions of the message-topic that are currently in state ASSIGNED
+  private final Set<TopicPartition> assignedPartitions = new HashSet<>();
+
+  // All arrays below are indexed by partition number and are allocated in
+  // run(), once the number of partitions of the message-topic is known.
+  // NOTE: volatile only covers the array reference, not the array elements.
+
+  // Current state of each partition
+  private volatile State[] partitionStates;
+  // Counters collected from the state-topic while a partition is RESTORING
+  // (null for partitions that are not being restored)
+  private Map<String, Long>[] restoredState;
+  // Live counters of the partitions in state ASSIGNED (null otherwise)
+  private CounterState[] counterState;
+  // End-offset of the state partition at the time restoring was started
+  private volatile long[] stateEndOffsets;
+  // Number of counter values handed to the producer in the current round
+  private volatile int[] seen;
+  // Number of producer callbacks that were executed in the current round
+  private volatile int[] acked;
+  // true, once all polled records of the partition were handled in the round
+  private volatile boolean[] done;
+  // Total number of records handled by this instance (both topics)
+  private long consumed = 0;
+
+
+  /**
+   * Creates the consumer and immediately starts its worker-thread.
+   *
+   * @param clientId   id of this instance, used as prefix in the log messages
+   * @param topic      topic with the messages that are counted
+   * @param consumer   consumer used to read the message- and the state-topic
+   * @param stateTopic topic the counter values are sent to and restored from
+   * @param producer   producer used to send the counter values
+   */
+  public ExampleConsumer(
+    String clientId,
+    String topic,
+    Consumer<String, String> consumer,
+    String stateTopic,
+    Producer<String, String> producer)
+  {
+    // Collaborators
+    this.consumer = consumer;
+    this.producer = producer;
+
+    // Configuration
+    this.id = clientId;
+    this.topic = topic;
+    this.stateTopic = stateTopic;
+
+    // The thread is started last, after all fields have been assigned
+    this.workerThread = new Thread(this, "ExampleConsumer Worker-Thread");
+    this.workerThread.start();
+  }
+
+
+  /**
+   * Main loop of the worker-thread.
+   * <p>
+   * Allocates the per-partition bookkeeping, subscribes to the message- and
+   * the state-topic and then repeats the following round until
+   * {@link #shutdown()} is called or an error occurs:
+   * <ol>
+   *   <li>poll for records (rebalance callbacks are invoked by the consumer)</li>
+   *   <li>reset {@code seen}, {@code acked} and {@code done} for all assigned partitions</li>
+   *   <li>handle all polled records, partition by partition</li>
+   *   <li>arrive at the phaser on behalf of each assigned partition, for which
+   *       nothing was sent in this round</li>
+   *   <li>wait until all parties have arrived</li>
+   * </ol>
+   * For partitions with sent counter values, the arrival is signaled by the
+   * send-callback in {@code sendCounterState}.
+   */
+  @Override
+  public void run()
+  {
+    try
+    {
+      log.info("{} - Fetching PartitionInfo for topic {}", id, topic);
+      int numPartitions = consumer.partitionsFor(topic).size();
+      log.info("{} - Topic {} has {} partitions", id, topic, numPartitions);
+
+      // Fill the array first and publish it through the volatile field afterwards
+      State[] initialStates = new State[numPartitions];
+      Arrays.fill(initialStates, State.UNASSIGNED);
+      partitionStates = initialStates;
+
+      restoredState = new Map[numPartitions];
+      counterState = new CounterState[numPartitions];
+      stateEndOffsets = new long[numPartitions];
+      seen = new int[numPartitions];
+      acked = new int[numPartitions];
+      done = new boolean[numPartitions];
+
+      log.info("{} - Subscribing to topic {}", id, topic);
+      consumer.subscribe(Arrays.asList(topic, stateTopic), this);
+      running = true;
+
+      while (running)
+      {
+        ConsumerRecords<String, String> records =
+            consumer.poll(Duration.ofSeconds(1));
+
+        int phase = phaser.getPhase();
+
+        // Start of a new round: reset the bookkeeping of all assigned partitions
+        assignedPartitions
+          .forEach(partition ->
+          {
+            seen[partition.partition()] = 0;
+            acked[partition.partition()] = 0;
+            done[partition.partition()] = false;
+          });
+
+        log.info("{} - Received {} messages in phase {}", id, records.count(), phase);
+        records
+          .partitions()
+          .forEach(partition ->
+          {
+            for (ConsumerRecord<String, String> record : records.records(partition))
+            {
+              handleRecord(
+                record.topic(),
+                record.partition(),
+                record.offset(),
+                record.key(),
+                record.value());
+            }
+
+            // NOTE(review): the send-callback only arrives at the phaser, if it
+            // sees done == true. If all callbacks for this partition were
+            // already executed before this line, nobody arrives for the
+            // partition and the loop blocks in arriveAndAwaitAdvance() - verify!
+            done[partition.partition()] = true;
+          });
+
+        // Partitions without sent counter values get no callback: arrive for them
+        assignedPartitions
+          .forEach(partition ->
+          {
+            if (seen[partition.partition()] == 0)
+            {
+              int arrivedPhase = phaser.arrive();
+              log.debug("{} - Received no records for partition {} in phase {}", id, partition, arrivedPhase);
+            }
+          });
+
+        // The worker-thread is a party itself and waits for all partitions
+        int arrivedPhase = phaser.arriveAndAwaitAdvance();
+        log.info("{} - Phase {} is done! Next phase: {}", id, phase, arrivedPhase);
+      }
+    }
+    catch(WakeupException e)
+    {
+      log.info("{} - Consumer was signaled to finish its work", id);
+    }
+    catch(Exception e)
+    {
+      // The exception is passed as trailing argument, so that the stack trace
+      // is logged, too - the message itself is unchanged
+      log.error("{} - Unexpected error: {}, unsubscribing!", id, e.toString(), e);
+      consumer.unsubscribe();
+    }
+    finally
+    {
+      log.info("{}: Consumed {} messages in total, exiting!", id, consumed);
+    }
+  }
+
+  /**
+   * Counts and logs the record and dispatches it by its topic: records from
+   * the message-topic are counted, all other records are treated as state.
+   *
+   * @param topic     topic the record was read from
+   * @param partition partition the record was read from
+   * @param offset    offset of the record
+   * @param key       key of the record
+   * @param value     value of the record
+   */
+  private void handleRecord(
+    String topic,
+    Integer partition,
+    Long offset,
+    String key,
+    String value)
+  {
+    consumed++;
+    log.info("{} - {}: {}/{} - {}={}", id, offset, topic, partition, key, value);
+
+    boolean isStateRecord = !topic.equals(this.topic);
+    if (isStateRecord)
+    {
+      handleState(partition, offset, key, value);
+      return;
+    }
+
+    handleMessage(partition, key);
+  }
+
+  /**
+   * Handles a record from the state-topic while a partition is restored.
+   * <p>
+   * The counter value is added to the restored state of the partition. As soon
+   * as the end-offset, that was recorded when restoring started, is reached,
+   * the partition is switched to state {@code ASSIGNED}.
+   * <p>
+   * Records for partitions, that are not in state {@code RESTORING}, are
+   * ignored: no restored state exists for them, so handling them would fail
+   * with a {@link NullPointerException}.
+   *
+   * @param partition number of the state partition the record was read from
+   * @param offset    offset of the record
+   * @param key       key of the counter
+   * @param value     value of the counter, must be parseable as {@code long}
+   * @throws NumberFormatException if the value is not a valid {@code long}
+   */
+  private synchronized void handleState(
+    int partition,
+    long offset,
+    String key,
+    String value)
+  {
+    Map<String, Long> restoring = restoredState[partition];
+    if (partitionStates[partition] != State.RESTORING || restoring == null)
+    {
+      log.warn(
+        "{} - Ignoring state {}={} at offset {}: partition {} is in state {}",
+        id,
+        key,
+        value,
+        offset,
+        partition,
+        partitionStates[partition]);
+      return;
+    }
+
+    restoring.put(key, Long.parseLong(value));
+
+    // >= instead of ==: if no record exists at the offset directly before the
+    // recorded end-offset, restoring still ends with the next record read
+    if (offset + 1 >= stateEndOffsets[partition])
+    {
+      log.info("{} - Restoring of state for partition {} done!", id, partition);
+      stateAssigned(partition);
+    }
+    else
+    {
+      log.debug(
+        "{} - Restored state up to offset {}, end-offset: {}, state: {}={}",
+        id,
+        offset,
+        stateEndOffsets[partition],
+        key,
+        value);
+    }
+  }
+
+  /**
+   * Handles a record from the message-topic: increments the counter for the
+   * key, logs the new value and sends it to the state-topic.
+   *
+   * @param partition partition the message was read from
+   * @param key       key of the message, that identifies the counter
+   */
+  private void handleMessage(
+    Integer partition,
+    String key)
+  {
+    final Long updatedCount = computeCount(partition, key);
+    log.info("{} - current value for counter {}: {}", id, key, updatedCount);
+    sendCounterState(partition, key, updatedCount);
+  }
+
+  /**
+   * Adds to the counter for the given key in the counter state of the partition.
+   *
+   * @param partition number of the partition the counter belongs to
+   * @param key       key of the counter
+   * @return the value returned by {@code CounterState.addToCounter(key)}
+   */
+  private synchronized Long computeCount(int partition, String key)
+  {
+    CounterState countersOfPartition = counterState[partition];
+    return countersOfPartition.addToCounter(key);
+  }
+
+  /**
+   * Collects the counter state of all partitions that are currently assigned.
+   * <p>
+   * The method is synchronized on this instance, like the methods that
+   * modify the counters and the revocation callback, because it is the public
+   * reader of state that is mutated by the worker-thread.
+   *
+   * @return a new map from partition number to the counter state of that partition
+   */
+  public synchronized Map<Integer, Map<String, Long>> getCounterState()
+  {
+    Map<Integer, Map<String, Long>> result = new HashMap<>(assignedPartitions.size());
+    for (TopicPartition assigned : assignedPartitions)
+    {
+      int partition = assigned.partition();
+      result.put(partition, counterState[partition].getCounterState());
+    }
+    return result;
+  }
+
+  /**
+   * Sends the current value of a counter to the state-topic and keeps track of
+   * the outstanding acknowledgements for the partition.
+   * <p>
+   * {@code seen} is incremented before the record is handed to the producer,
+   * {@code acked} is incremented in the send-callback - on success as well as
+   * on error. The callback arrives at the phaser on behalf of the partition,
+   * if the partition is marked as {@code done} and no acknowledgement is
+   * outstanding.
+   *
+   * @param partition number of the message partition the counter belongs to
+   * @param key       key of the counter
+   * @param counter   current value of the counter, must not be {@code null}
+   */
+  void sendCounterState(int partition, String key, Long counter)
+  {
+    seen[partition]++;
+
+    final long time = System.currentTimeMillis();
+
+    // NOTE(review): no partition is set on the record, so the producer picks
+    // it. restoreAndAssign() reads the state partition with the same number as
+    // the message partition - confirm that both topics are co-partitioned.
+    final ProducerRecord<String, String> record = new ProducerRecord<>(
+        stateTopic,        // Topic
+        key,               // Key
+        counter.toString() // Value
+    );
+
+    // NOTE(review): the callback is presumably executed by a thread of the
+    // producer, not by the worker-thread - verify. The elements of seen,
+    // acked and done are not volatile (only the array references are).
+    producer.send(record, (metadata, e) ->
+    {
+      long now = System.currentTimeMillis();
+      if (e == null)
+      {
+        // HANDLE SUCCESS
+        log.debug(
+            "{} - Sent message {}={}, partition={}:{}, timestamp={}, latency={}ms",
+            id,
+            record.key(),
+            record.value(),
+            metadata.partition(),
+            metadata.offset(),
+            metadata.timestamp(),
+            now - time
+        );
+      }
+      else
+      {
+        // HANDLE ERROR
+        // The error is only logged: the message still counts as acknowledged
+        log.error(
+            "{} - ERROR for message {}={}, timestamp={}, latency={}ms: {}",
+            id,
+            record.key(),
+            record.value(),
+            metadata == null ? -1 : metadata.timestamp(),
+            now - time,
+            e.toString()
+        );
+      }
+
+      acked[partition]++;
+      // Arrive only, if run() has handed over all records of this round
+      // (done) and every sent record was acknowledged (acked >= seen).
+      // NOTE(review): if the last callback runs before run() sets done, this
+      // branch is never taken for the round - see the loop in run().
+      if (done[partition] && !(acked[partition] < seen[partition]))
+      {
+        int arrivedPhase = phaser.arrive();
+        log.debug(
+            "{} - Arrived at phase {} for partition {}, seen={}, acked={}",
+            id,
+            arrivedPhase,
+            partition,
+            seen[partition],
+            acked[partition]);
+      }
+      else
+      {
+        log.debug(
+            "{} - Still in phase {} for partition {}, seen={}, acked={}",
+            id,
+            phaser.getPhase(),
+            partition,
+            seen[partition],
+            acked[partition]);
+      }
+    });
+
+    // Time it took to hand the record over to the producer
+    long now = System.currentTimeMillis();
+    log.trace(
+        "{} - Queued message {}={}, latency={}ms",
+        id,
+        record.key(),
+        record.value(),
+        now - time
+    );
+  }
+
+  /**
+   * Rebalance callback: starts the restoring of the state for each newly
+   * assigned partition of the message-topic. Partitions of other topics are
+   * ignored here.
+   * <p>
+   * The method is synchronized on this instance, like
+   * {@link #onPartitionsRevoked(Collection)}, because it modifies the same
+   * per-partition state and the set of assigned partitions.
+   *
+   * @param partitions the partitions that were newly assigned to the consumer
+   */
+  @Override
+  public synchronized void onPartitionsAssigned(Collection<TopicPartition> partitions)
+  {
+    partitions
+      .stream()
+      .filter(partition -> partition.topic().equals(topic))
+      .forEach(partition -> restoreAndAssign(partition.partition()));
+  }
+
+  /**
+   * Rebalance callback: revokes each partition of the message-topic that was
+   * taken away from the consumer. Partitions of other topics are ignored here.
+   *
+   * @param partitions the partitions that were revoked from the consumer
+   */
+  @Override
+  public synchronized void onPartitionsRevoked(Collection<TopicPartition> partitions)
+  {
+    for (TopicPartition revoked : partitions)
+    {
+      if (revoked.topic().equals(topic))
+      {
+        revoke(revoked.partition());
+      }
+    }
+  }
+
+  /**
+   * Looks up the beginning- and the end-offset of the state partition with
+   * the given number. If the state partition contains records, restoring is
+   * started. Otherwise, the partition is assigned with an empty state.
+   *
+   * @param partition number of the newly assigned message partition
+   */
+  private void restoreAndAssign(int partition)
+  {
+    final TopicPartition statePartition = new TopicPartition(stateTopic, partition);
+    final List<TopicPartition> lookup = List.of(statePartition);
+
+    final long stateEndOffset = consumer.endOffsets(lookup).get(statePartition).longValue();
+    final long stateBeginningOffset = consumer.beginningOffsets(lookup).get(statePartition);
+
+    log.info(
+      "{} - Found beginning-offset {} and end-offset {} for state partition {}",
+      id,
+      stateBeginningOffset,
+      stateEndOffset,
+      partition);
+
+    // Nothing to restore: assign the partition right away with an empty state
+    if (stateBeginningOffset >= stateEndOffset)
+    {
+      log.info("{} - No state available for partition {}", id, partition);
+      restoredState[partition] = new HashMap<>();
+      stateAssigned(partition);
+      return;
+    }
+
+    stateRestoring(partition, stateBeginningOffset, stateEndOffset);
+  }
+
+  /**
+   * Switches a revoked partition to state {@code UNASSIGNED}. If the
+   * partition already is in that state, only a warning is logged.
+   *
+   * @param partition number of the revoked partition
+   */
+  private void revoke(int partition)
+  {
+    final State current = partitionStates[partition];
+    switch (current)
+    {
+      case UNASSIGNED:
+        log.warn("{} - partition {} in state {} was revoked!", id, partition, current);
+        break;
+      case RESTORING:
+      case ASSIGNED:
+        stateUnassigned(partition);
+        break;
+    }
+  }
+
+  /**
+   * Switches the partition to state {@code RESTORING}: the message partition
+   * is paused, and the state partition is read from its first offset on, until
+   * the given end-offset is reached (see {@code handleState}).
+   *
+   * @param partition            number of the partition
+   * @param stateBeginningOffset first offset of the state partition
+   * @param stateEndOffset       end-offset of the state partition
+   */
+  private void stateRestoring(int partition, long stateBeginningOffset, long stateEndOffset)
+  {
+    final TopicPartition messagePartition = new TopicPartition(topic, partition);
+    final TopicPartition statePartition = new TopicPartition(stateTopic, partition);
+    final State previousState = partitionStates[partition];
+
+    log.info(
+      "{} - Changing partition-state for {}: {} -> RESTORING",
+      id,
+      partition,
+      previousState);
+    partitionStates[partition] = State.RESTORING;
+
+    // No messages must be counted, before the state is restored completely
+    log.info("{} - Pausing message partition {}", id, messagePartition);
+    consumer.pause(Collections.singletonList(messagePartition));
+
+    log.info(
+      "{} - Seeking to first offset {} for state partition {}",
+      id,
+      stateBeginningOffset,
+      statePartition);
+    consumer.seek(statePartition, stateBeginningOffset);
+
+    // Prepare the bookkeeping, that is read by handleState()
+    stateEndOffsets[partition] = stateEndOffset;
+    restoredState[partition] = new HashMap<>();
+
+    log.info("{} - Resuming state partition {}", id, statePartition);
+    consumer.resume(Collections.singletonList(statePartition));
+  }
+
+  /**
+   * Switches the partition to state {@code ASSIGNED}: the state partition is
+   * paused, the counter state is created from the restored state, a party is
+   * registered at the phaser for the partition and the message partition is
+   * resumed.
+   * <p>
+   * The bookkeeping of the current round ({@code seen}, {@code acked},
+   * {@code done}) is reset for the partition. This method may be called from
+   * {@code handleState}, after {@code run()} has already reset the
+   * bookkeeping of the round for the partitions assigned at that point.
+   * Without the reset here, a stale {@code seen} value from an earlier
+   * ownership of the partition would prevent {@code run()} from arriving
+   * for the newly registered party, and the phase would never advance.
+   *
+   * @param partition number of the partition
+   */
+  private void stateAssigned(int partition)
+  {
+    log.info(
+      "{} - State-change for partition {}: {} -> ASSIGNED",
+      id,
+      partition,
+      partitionStates[partition]);
+
+    partitionStates[partition] = State.ASSIGNED;
+
+    TopicPartition statePartition = new TopicPartition(stateTopic, partition);
+    log.info("{} - Pausing state partition {}...", id, statePartition);
+    consumer.pause(List.of(statePartition));
+    counterState[partition] = new CounterState(restoredState[partition]);
+    restoredState[partition] = null;
+
+    // Start with a clean slate, before a party is registered for the partition
+    seen[partition] = 0;
+    acked[partition] = 0;
+    done[partition] = false;
+
+    TopicPartition messagePartition = new TopicPartition(topic, partition);
+    log.info("{} - Adding partition {} to the assigned partitions", id, messagePartition);
+    assignedPartitions.add(messagePartition);
+    phaser.register();
+    log.info(
+      "{} - Registered new partie for newly assigned partition {}. New total number of parties: {}",
+      id,
+      messagePartition,
+      phaser.getRegisteredParties());
+    log.info("{} - Resuming message partition {}...", id, messagePartition);
+    consumer.resume(List.of(messagePartition));
+  }
+
+  /**
+   * Switches the partition to state {@code UNASSIGNED}. If the partition was
+   * in state {@code ASSIGNED}, it is also removed from the assigned
+   * partitions, its counter state is dropped and its party is deregistered
+   * from the phaser.
+   *
+   * @param partition number of the partition
+   */
+  private void stateUnassigned(int partition)
+  {
+    final State previousState = partitionStates[partition];
+
+    log.info(
+      "{} - State-change for partition {}: {} -> UNASSIGNED",
+      id,
+      partition,
+      previousState);
+
+    partitionStates[partition] = State.UNASSIGNED;
+
+    // Only a partition in state ASSIGNED has a counter state and a party
+    if (previousState != State.ASSIGNED)
+    {
+      return;
+    }
+
+    final TopicPartition messagePartition = new TopicPartition(topic, partition);
+    log.info("{} - Revoking partition {}", id, messagePartition);
+    assignedPartitions.remove(messagePartition);
+    counterState[partition] = null;
+
+    phaser.arriveAndDeregister();
+    log.info(
+      "{} - Deregistered partie for revoked partition {}. New total number of parties: {}",
+      id,
+      messagePartition,
+      phaser.getRegisteredParties());
+  }
+
+
+  /**
+   * Signals the worker-thread to stop and waits until it has finished.
+   *
+   * @throws InterruptedException if the calling thread is interrupted while
+   *         it waits for the worker-thread
+   */
+  public void shutdown() throws InterruptedException
+  {
+    log.info("{} joining the worker-thread...", id);
+    // Ends the loop in run() after the current round
+    running = false;
+    // run() handles the resulting WakeupException as the signal to finish
+    // NOTE(review): this presumably does not help, if the worker-thread is
+    // blocked in phaser.arriveAndAwaitAdvance() - verify.
+    consumer.wakeup();
+    workerThread.join();
+  }
+
+  // State of a partition, as tracked in partitionStates
+  enum State
+  {
+    // Initial state; also set by stateUnassigned(), when a partition is revoked
+    UNASSIGNED,
+    // Set by stateRestoring(): the state partition is read, the message partition is paused
+    RESTORING,
+    // Set by stateAssigned(): the message partition is read, the state partition is paused
+    ASSIGNED
+  }
+}