import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.KafkaConsumer;
+import org.apache.kafka.clients.consumer.RangeAssignor;
+import org.apache.kafka.clients.producer.KafkaProducer;
+import org.apache.kafka.clients.producer.Producer;
 import org.apache.kafka.common.serialization.StringDeserializer;
+import org.apache.kafka.common.serialization.StringSerializer;
 import org.springframework.boot.context.properties.EnableConfigurationProperties;
 import org.springframework.context.ConfigurableApplicationContext;
 import org.springframework.context.annotation.Bean;
   @Bean
   public ExampleConsumer exampleConsumer(
     Consumer<String, String> kafkaConsumer,
+    Producer<String, String> kafkaProducer,
     ApplicationProperties properties,
     ConfigurableApplicationContext applicationContext)
   {
         properties.getClientId(),
         properties.getConsumerProperties().getTopic(),
         kafkaConsumer,
+        properties.getProducerProperties().getTopic(),
+        kafkaProducer,
         () -> applicationContext.close());
   }
 
       props.put("auto.commit.interval", properties.getConsumerProperties().getAutoCommitInterval());
     }
     props.put("metadata.max.age.ms", 5000); //  5 Sekunden
+    props.put("partition.assignment.strategy", RangeAssignor.class.getName());
     props.put("key.deserializer", StringDeserializer.class.getName());
     props.put("value.deserializer", StringDeserializer.class.getName());
 
     return new KafkaConsumer<>(props);
   }
+
+  @Bean
+  public KafkaProducer<String, String> kafkaProducer(ApplicationProperties properties)
+  {
+    Properties props = new Properties();
+    props.put("bootstrap.servers", properties.getBootstrapServer());
+    props.put("client.id", properties.getClientId());
+    props.put("acks", properties.getProducerProperties().getAcks());
+    props.put("batch.size", properties.getProducerProperties().getBatchSize());
+    props.put("metadata.max.age.ms",   5000); //  5 Sekunden
+    props.put("delivery.timeout.ms", 20000); // 20 Sekunden
+    props.put("request.timeout.ms",  10000); // 10 Sekunden
+    props.put("linger.ms", properties.getProducerProperties().getLingerMs());
+    props.put("compression.type", properties.getProducerProperties().getCompressionType());
+    props.put("key.serializer", StringSerializer.class.getName());
+    props.put("value.serializer", StringSerializer.class.getName());
+
+    return new KafkaProducer<>(props);
+  }
 }
 
 import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.producer.Producer;
+import org.apache.kafka.clients.producer.ProducerRecord;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.errors.WakeupException;
 
 import java.time.Duration;
 import java.util.*;
+import java.util.concurrent.Phaser;
 
 
 @Slf4j
   private final Thread workerThread;
   private final Runnable closeCallback;
 
+  private final String stateTopic;
+  private final Producer<String, String> producer;
+
+  private final Phaser phaser = new Phaser(1);
   private final Set<TopicPartition> assignedPartitions = new HashSet<>();
-  private CounterState<String>[] counterState;
+  private volatile State[] partitionStates;
+  private Map<String, Long>[] restoredState;
+  private CounterState[] counterState;
+  private volatile long[] stateEndOffsets;
+  private volatile int[] seen;
+  private volatile int[] acked;
+  private volatile boolean[] done;
   private long consumed = 0;
 
 
     String clientId,
     String topic,
     Consumer<String, String> consumer,
+    String stateTopic,
+    Producer<String, String> producer,
     Runnable closeCallback)
   {
     this.id = clientId;
     this.topic = topic;
     this.consumer = consumer;
+    this.stateTopic = stateTopic;
+    this.producer = producer;
 
     workerThread = new Thread(this, "ExampleConsumer Worker-Thread");
     workerThread.start();
       log.info("{} - Fetching PartitionInfo for topic {}", id, topic);
       int numPartitions = consumer.partitionsFor(topic).size();
       log.info("{} - Topic {} has {} partitions", id, topic, numPartitions);
+      partitionStates = new State[numPartitions];
+      for (int i=0; i<numPartitions; i++)
+      {
+        partitionStates[i] = State.UNASSIGNED;
+      }
+      restoredState = new Map[numPartitions];
       counterState = new CounterState[numPartitions];
+      stateEndOffsets = new long[numPartitions];
+      seen = new int[numPartitions];
+      acked = new int[numPartitions];
+      done = new boolean[numPartitions];
 
       log.info("{} - Subscribing to topic {}", id, topic);
-      consumer.subscribe(Arrays.asList(topic), this);
+      consumer.subscribe(Arrays.asList(topic, stateTopic), this);
 
       while (true)
       {
         ConsumerRecords<String, String> records = consumer.poll(Duration.ofSeconds(1));
 
-        log.info("{} - Received {} messages", id, records.count());
-        for (ConsumerRecord<String, String> record : records)
-        {
-          handleRecord(
-            record.topic(),
-            record.partition(),
-            record.offset(),
-            record.key(),
-            record.value());
-        }
+        int phase = phaser.getPhase();
+
+        assignedPartitions
+          .forEach(partition ->
+          {
+            seen[partition.partition()] = 0;
+            acked[partition.partition()] = 0;
+            done[partition.partition()] = false;
+          });
+
+        log.info("{} - Received {} messages in phase {}", id, records.count(), phase);
+        records
+          .partitions()
+          .forEach(partition ->
+          {
+            for (ConsumerRecord<String, String> record : records.records(partition))
+            {
+              handleRecord(
+                record.topic(),
+                record.partition(),
+                record.offset(),
+                record.key(),
+                record.value());
+            }
+
+            checkRestoreProgress(partition);
+
+            done[partition.partition()] = true;
+          });
+
+        assignedPartitions
+          .forEach(partition ->
+          {
+            if (seen[partition.partition()] == 0)
+            {
+              int arrivedPhase = phaser.arrive();
+              log.debug("{} - Received no records for partition {} in phase {}", id, partition, arrivedPhase);
+            }
+          });
+
+        int arrivedPhase = phaser.arriveAndAwaitAdvance();
+        log.info("{} - Phase {} is done! Next phase: {}", id, phase, arrivedPhase);
       }
     }
     catch(WakeupException e)
     consumed++;
     log.info("{} - partition={}-{}, offset={}: {}={}", id, topic, partition, offset, key, value);
 
+    if (topic.equals(this.topic))
+    {
+      handleMessage(partition, key);
+    }
+    else
+    {
+      handleState(partition, key, value);
+    }
+  }
+
+  private void checkRestoreProgress(TopicPartition topicPartition)
+  {
+    int partition = topicPartition.partition();
+
+    if (partitionStates[partition] == State.RESTORING)
+    {
+      long consumerPosition = consumer.position(topicPartition);
+
+      if (consumerPosition + 1 >= stateEndOffsets[partition])
+      {
+        log.info(
+          "{} - Position of consumer is {}. Restoring of state for partition {} done!",
+          id,
+          consumerPosition,
+          topicPartition);
+        stateAssigned(partition);
+      }
+      else
+      {
+        log.debug(
+          "{} - Restored state up to offset {}, end-offset: {}",
+          id,
+          consumerPosition,
+          stateEndOffsets[partition]);
+      }
+    }
+  }
+
+  private synchronized void handleState(
+    int partition,
+    String key,
+    String value)
+  {
+    restoredState[partition].put(key, Long.parseLong(value));
+  }
+
+  private void handleMessage(
+    Integer partition,
+    String key)
+  {
     Long counter = computeCount(partition, key);
     log.info("{} - current value for counter {}: {}", id, key, counter);
+    sendCounterState(partition, key, counter);
   }
 
   private synchronized Long computeCount(int partition, String key)
     return result;
   }
 
+  void sendCounterState(int partition, String key, Long counter)
+  {
+    seen[partition]++;
+
+    final long time = System.currentTimeMillis();
+
+    final ProducerRecord<String, String> record = new ProducerRecord<>(
+      stateTopic,        // Topic
+      key,               // Key
+      counter.toString() // Value
+    );
+
+    producer.send(record, (metadata, e) ->
+    {
+      long now = System.currentTimeMillis();
+      if (e == null)
+      {
+        // HANDLE SUCCESS
+        log.debug(
+          "{} - Sent message {}={}, partition={}:{}, timestamp={}, latency={}ms",
+          id,
+          record.key(),
+          record.value(),
+          metadata.partition(),
+          metadata.offset(),
+          metadata.timestamp(),
+          now - time
+        );
+      }
+      else
+      {
+        // HANDLE ERROR
+        log.error(
+          "{} - ERROR for message {}={}, timestamp={}, latency={}ms: {}",
+          id,
+          record.key(),
+          record.value(),
+          metadata == null ? -1 : metadata.timestamp(),
+          now - time,
+          e.toString()
+        );
+      }
+
+      acked[partition]++;
+      if (done[partition] && !(acked[partition] < seen[partition]))
+      {
+        int arrivedPhase = phaser.arrive();
+        log.debug(
+          "{} - Arrived at phase {} for partition {}, seen={}, acked={}",
+          id,
+          arrivedPhase,
+          partition,
+          seen[partition],
+          acked[partition]);
+      }
+      else
+      {
+        log.debug(
+          "{} - Still in phase {} for partition {}, seen={}, acked={}",
+          id,
+          phaser.getPhase(),
+          partition,
+          seen[partition],
+          acked[partition]);
+      }
+    });
+
+    long now = System.currentTimeMillis();
+    log.trace(
+      "{} - Queued message {}={}, latency={}ms",
+      id,
+      record.key(),
+      record.value(),
+      now - time
+    );
+  }
 
   @Override
   public void onPartitionsAssigned(Collection<TopicPartition> partitions)
     partitions
       .stream()
       .filter(partition -> partition.topic().equals(topic))
-      .forEach(partition ->
-      {
-        assignedPartitions.add(partition);
-        counterState[partition.partition()] = new CounterState<>(new HashMap<>());
-      });
+      .forEach(partition -> restoreAndAssign(partition.partition()));
   }
 
   @Override
     partitions
       .stream()
       .filter(partition -> partition.topic().equals(topic))
-      .forEach(partition ->
-      {
-        assignedPartitions.remove(partition);
-        counterState[partition.partition()] = null;
-      });
+      .forEach(partition -> revoke(partition.partition()));
+  }
+
+  private void restoreAndAssign(int partition)
+  {
+    TopicPartition statePartition = new TopicPartition(this.stateTopic, partition);
+
+    long stateEndOffset = consumer
+      .endOffsets(List.of(statePartition))
+      .get(statePartition)
+      .longValue();
+
+    long stateBeginningOffset = consumer
+      .beginningOffsets(List.of(statePartition))
+      .get(statePartition);
+
+    log.info(
+      "{} - Found beginning-offset {} and end-offset {} for state partition {}",
+      id,
+      stateBeginningOffset,
+      stateEndOffset,
+      partition);
+
+    if (stateBeginningOffset < stateEndOffset)
+    {
+      stateRestoring(partition, stateBeginningOffset, stateEndOffset);
+    }
+    else
+    {
+      log.info("{} - No state available for partition {}", id, partition);
+      restoredState[partition] = new HashMap<>();
+      stateAssigned(partition);
+    }
+  }
+
+  private void revoke(int partition)
+  {
+    State partitionState = partitionStates[partition];
+    switch (partitionState)
+    {
+      case RESTORING, ASSIGNED -> stateUnassigned(partition);
+      case UNASSIGNED -> log.warn("{} - partition {} in state {} was revoked!", id, partition, partitionState);
+    }
+  }
+
+  private void stateRestoring(int partition, long stateBeginningOffset, long stateEndOffset)
+  {
+    log.info(
+      "{} - Changing partition-state for {}: {} -> RESTORING",
+      id,
+      partition,
+      partitionStates[partition]);
+    partitionStates[partition] = State.RESTORING;
+
+    TopicPartition messagePartition = new TopicPartition(this.topic, partition);
+    log.info("{} - Pausing message partition {}", id, messagePartition);
+    consumer.pause(List.of(messagePartition));
+
+    TopicPartition statePartition = new TopicPartition(this.stateTopic, partition);
+    log.info(
+      "{} - Seeking to first offset {} for state partition {}",
+      id,
+      stateBeginningOffset,
+      statePartition);
+    consumer.seek(statePartition, stateBeginningOffset);
+    stateEndOffsets[partition] = stateEndOffset;
+    restoredState[partition] = new HashMap<>();
+    log.info("{} - Resuming state partition {}", id, statePartition);
+    consumer.resume(List.of(statePartition));
+  }
+
+  private void stateAssigned(int partition)
+  {
+    log.info(
+      "{} - State-change for partition {}: {} -> ASSIGNED",
+      id,
+      partition,
+      partitionStates[partition]);
+
+    partitionStates[partition] = State.ASSIGNED;
+
+    TopicPartition statePartition = new TopicPartition(stateTopic, partition);
+    log.info("{} - Pausing state partition {}...", id, statePartition);
+    consumer.pause(List.of(statePartition));
+    counterState[partition] = new CounterState(restoredState[partition]);
+    restoredState[partition] = null;
+
+    TopicPartition messagePartition = new TopicPartition(topic, partition);
+    log.info("{} - Adding partition {} to the assigned partitions", id, messagePartition);
+    assignedPartitions.add(messagePartition);
+    phaser.register();
+    log.info(
+      "{} - Registered new party for newly assigned partition {}. New total number of parties: {}",
+      id,
+      messagePartition,
+      phaser.getRegisteredParties());
+    log.info("{} - Resuming message partition {}...", id, messagePartition);
+    consumer.resume(List.of(messagePartition));
+  }
+
+  private void stateUnassigned(int partition)
+  {
+    State oldPartitionState = partitionStates[partition];
+
+    log.info(
+      "{} - State-change for partition {}: {} -> UNASSIGNED",
+      id,
+      partition,
+      oldPartitionState);
+
+    partitionStates[partition] = State.UNASSIGNED;
+
+    if (oldPartitionState == State.ASSIGNED)
+    {
+      TopicPartition messagePartition = new TopicPartition(topic, partition);
+      log.info("{} - Revoking partition {}", id, messagePartition);
+      assignedPartitions.remove(messagePartition);
+      counterState[partition] = null;
+
+      phaser.arriveAndDeregister();
+      log.info(
+        "{} - Deregistered party for revoked partition {}. New total number of parties: {}",
+        id,
+        messagePartition,
+        phaser.getRegisteredParties());
+    }
   }
 
 
     log.info("{} - Joining the worker thread", id);
     workerThread.join();
   }
+
+  enum State
+  {
+    UNASSIGNED,
+    RESTORING,
+    ASSIGNED
+  }
 }