From 9efcd6d0bb8b2391f1a06a4735ea36e4ba10395a Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Mon, 9 Feb 2026 08:51:03 +0100 Subject: [PATCH 01/22] refactor: use Optional --- .../v1/api/collections/CollectionHandle.java | 7 ++++++- .../collections/CollectionHandleAsync.java | 3 ++- .../collections/CollectionHandleDefaults.java | 17 ++++++++++++----- .../collections/data/DeleteManyRequest.java | 4 ++-- .../collections/data/InsertManyRequest.java | 19 ++++++++++--------- .../collections/data/InsertObjectRequest.java | 4 ++-- .../api/collections/data/ObjectReference.java | 4 ++++ .../data/ReferenceAddManyRequest.java | 4 ++++ .../data/ReplaceObjectRequest.java | 4 ++-- .../collections/data/UpdateObjectRequest.java | 4 ++-- .../collections/query/ConsistencyLevel.java | 4 ++++ .../api/collections/query/QueryRequest.java | 5 +++-- .../api/collections/query/QueryResponse.java | 14 +------------- .../CollectionHandleDefaultsTest.java | 14 +++++++------- .../api/collections/CollectionHandleTest.java | 4 ++-- 15 files changed, 63 insertions(+), 48 deletions(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java index 7af8ed549..7bbc94c19 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java @@ -1,9 +1,11 @@ package io.weaviate.client6.v1.api.collections; import java.util.Collection; +import java.util.Optional; import java.util.function.Function; import io.weaviate.client6.v1.api.collections.aggregate.WeaviateAggregateClient; +import io.weaviate.client6.v1.api.collections.batch.WeaviateBatchClient; import io.weaviate.client6.v1.api.collections.config.WeaviateConfigClient; import io.weaviate.client6.v1.api.collections.data.WeaviateDataClient; import io.weaviate.client6.v1.api.collections.generate.WeaviateGenerateClient; @@ -23,6 +25,7 @@ public class 
CollectionHandle { public final WeaviateAggregateClient aggregate; public final WeaviateGenerateClient generate; public final WeaviateTenantsClient tenants; + public final WeaviateBatchClient batch; private final CollectionHandleDefaults defaults; @@ -36,6 +39,7 @@ public CollectionHandle( this.query = new WeaviateQueryClient<>(collection, grpcTransport, defaults); this.generate = new WeaviateGenerateClient<>(collection, grpcTransport, defaults); this.data = new WeaviateDataClient<>(collection, restTransport, grpcTransport, defaults); + this.batch = new WeaviateBatchClient<>(grpcTransport, defaults); this.defaults = defaults; this.tenants = new WeaviateTenantsClient(collection, restTransport, grpcTransport); @@ -48,6 +52,7 @@ private CollectionHandle(CollectionHandle c, CollectionHandleDefaul this.query = new WeaviateQueryClient<>(c.query, defaults); this.generate = new WeaviateGenerateClient<>(c.generate, defaults); this.data = new WeaviateDataClient<>(c.data, defaults); + this.batch = new WeaviateBatchClient<>(c.batch, defaults); this.defaults = defaults; this.tenants = c.tenants; @@ -112,7 +117,7 @@ public long size() { } /** Default consistency level for requests. */ - public ConsistencyLevel consistencyLevel() { + public Optional consistencyLevel() { return defaults.consistencyLevel(); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java index 83d18ed2f..0b29d2c82 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java @@ -1,6 +1,7 @@ package io.weaviate.client6.v1.api.collections; import java.util.Collection; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.function.Function; @@ -90,7 +91,7 @@ public CompletableFuture size() { } /** Default consistency level for requests. 
*/ - public ConsistencyLevel consistencyLevel() { + public Optional consistencyLevel() { return defaults.consistencyLevel(); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaults.java b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaults.java index c7952222a..f89e9f6f1 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaults.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaults.java @@ -1,14 +1,17 @@ package io.weaviate.client6.v1.api.collections; +import static java.util.Objects.requireNonNull; + import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.Optional; import java.util.function.Function; import io.weaviate.client6.v1.api.collections.query.ConsistencyLevel; import io.weaviate.client6.v1.internal.ObjectBuilder; -public record CollectionHandleDefaults(ConsistencyLevel consistencyLevel, String tenant) { +public record CollectionHandleDefaults(Optional consistencyLevel, String tenant) { /** * Set default values for query / aggregation requests. * @@ -28,8 +31,12 @@ public static Function> none() return ObjectBuilder.identity(); } + public CollectionHandleDefaults { + requireNonNull(consistencyLevel, "consistencyLevel is null"); + } + public CollectionHandleDefaults(Builder builder) { - this(builder.consistencyLevel, builder.tenant); + this(Optional.of(builder.consistencyLevel), builder.tenant); } public static final class Builder implements ObjectBuilder { @@ -56,12 +63,12 @@ public CollectionHandleDefaults build() { /** Serialize default values to a URL query. 
*/ public Map queryParameters() { - if (consistencyLevel == null && tenant == null) { + if (consistencyLevel.isEmpty() && tenant == null) { return Collections.emptyMap(); } var query = new HashMap(); - if (consistencyLevel != null) { - query.put("consistency_level", consistencyLevel); + if (consistencyLevel.isPresent()) { + query.put("consistency_level", consistencyLevel.get()); } if (tenant != null) { query.put("tenant", tenant); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyRequest.java index 2fff8681a..601ea072d 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyRequest.java @@ -32,8 +32,8 @@ public static Rpc Rpc, WeaviateProtoBat request -> { var message = WeaviateProtoBatch.BatchObjectsRequest.newBuilder(); - var batch = request.objects.stream().map(obj -> { - var batchObject = WeaviateProtoBatch.BatchObject.newBuilder(); - buildObject(batchObject, obj, collection, defaults); - return batchObject.build(); - }).toList(); - + var batch = request.objects.stream() + .map(obj -> buildObject(obj, collection, defaults)) + .toList(); message.addAllObjects(batch); - if (defaults.consistencyLevel() != null) { - defaults.consistencyLevel().appendTo(message); + if (defaults.consistencyLevel().isPresent()) { + defaults.consistencyLevel().get().appendTo(message); } + var m = message.build(); + m.getSerializedSize(); return message.build(); }, response -> { @@ -92,10 +91,11 @@ public static Rpc, WeaviateProtoBat () -> WeaviateFutureStub::batchObjects); } - public static void buildObject(WeaviateProtoBatch.BatchObject.Builder object, + public static WeaviateProtoBatch.BatchObject buildObject( WeaviateObject insert, CollectionDescriptor collection, CollectionHandleDefaults defaults) { + var object = WeaviateProtoBatch.BatchObject.newBuilder(); 
object.setCollection(collection.collectionName()); if (insert.uuid() != null) { @@ -158,6 +158,7 @@ public static void buildObject(WeaviateProtoBatch.BatchObject.Builder object properties.setNonRefProperties(nonRef); } object.setProperties(properties); + return object.build(); } @SuppressWarnings("unchecked") diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java index 8588eb760..86840ed5d 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java @@ -27,8 +27,8 @@ public static final Endpoint, Wea return new SimpleEndpoint<>( request -> "POST", request -> "/objects/", - request -> defaults.consistencyLevel() != null - ? Map.of("consistency_level", defaults.consistencyLevel()) + request -> defaults.consistencyLevel().isPresent() + ? Map.of("consistency_level", defaults.consistencyLevel().get()) : Collections.emptyMap(), request -> JSON.serialize( new WeaviateObject<>( diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/ObjectReference.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/ObjectReference.java index bb6a0f27e..822c5f54d 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/ObjectReference.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/ObjectReference.java @@ -67,6 +67,10 @@ public static ObjectReference[] collection(String collection, String... 
uuids) { .toArray(ObjectReference[]::new); } + public String beacon() { + return toBeacon(collection, uuid); + } + public static String toBeacon(String collection, String uuid) { return toBeacon(collection, null, uuid); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddManyRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddManyRequest.java index 284688daa..cfb2a4667 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddManyRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddManyRequest.java @@ -4,6 +4,7 @@ import java.util.List; import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch; import io.weaviate.client6.v1.internal.json.JSON; import io.weaviate.client6.v1.internal.rest.Endpoint; import io.weaviate.client6.v1.internal.rest.SimpleEndpoint; @@ -32,4 +33,7 @@ public static final Endpoint }); } + public static WeaviateProtoBatch.BatchReference buildReference(ObjectReference reference) { + return null; + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReplaceObjectRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReplaceObjectRequest.java index 13a1afacb..81dbb2428 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReplaceObjectRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReplaceObjectRequest.java @@ -26,8 +26,8 @@ static final Endpoint, Void> end return SimpleEndpoint.sideEffect( request -> "PUT", request -> "/objects/" + collection.collectionName() + "/" + request.object.uuid(), - request -> defaults.consistencyLevel() != null - ? Map.of("consistency_level", defaults.consistencyLevel()) + request -> defaults.consistencyLevel().isPresent() + ? 
Map.of("consistency_level", defaults.consistencyLevel().get()) : Collections.emptyMap(), request -> JSON.serialize( new WeaviateObject<>( diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/UpdateObjectRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/UpdateObjectRequest.java index 6157a1cc8..cf0451d0d 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/UpdateObjectRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/UpdateObjectRequest.java @@ -26,8 +26,8 @@ static final Endpoint, Void> endp return SimpleEndpoint.sideEffect( request -> "PATCH", request -> "/objects/" + collection.collectionName() + "/" + request.object.uuid(), - request -> defaults.consistencyLevel() != null - ? Map.of("consistency_level", defaults.consistencyLevel()) + request -> defaults.consistencyLevel().isPresent() + ? Map.of("consistency_level", defaults.consistencyLevel().get()) : Collections.emptyMap(), request -> JSON.serialize( new WeaviateObject<>( diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/ConsistencyLevel.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/ConsistencyLevel.java index 326609d2e..fc688cd41 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/ConsistencyLevel.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/ConsistencyLevel.java @@ -43,6 +43,10 @@ public final void appendTo(WeaviateProtoBatch.BatchObjectsRequest.Builder req) { req.setConsistencyLevel(consistencyLevel); } + public final void appendTo(WeaviateProtoBatch.BatchStreamRequest.Start.Builder req) { + req.setConsistencyLevel(consistencyLevel); + } + @Override public String toString() { return queryParameter; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java index 625dde30d..e723c2c03 100644 --- 
a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java @@ -40,8 +40,9 @@ public static WeaviateProtoSearchGet.SearchRequest marshal( if (defaults.tenant() != null) { message.setTenant(defaults.tenant()); } - if (defaults.consistencyLevel() != null) { - defaults.consistencyLevel().appendTo(message); + + if (defaults.consistencyLevel().isPresent()) { + defaults.consistencyLevel().get().appendTo(message); } if (request.groupBy != null) { diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java index b1bc7369e..751492591 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java @@ -3,7 +3,6 @@ import java.util.Arrays; import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.UUID; import java.util.stream.Stream; @@ -92,20 +91,9 @@ static WeaviateObject unmarshalWithReferences( (map, ref) -> { var refObjects = ref.getPropertiesList().stream() .map(property -> { - var reference = unmarshalWithReferences( + return (Reference) unmarshalWithReferences( property, property.getMetadata(), CollectionDescriptor.ofMap(property.getTargetCollection())); - return (Reference) new WeaviateObject<>( - reference.uuid(), - reference.collection(), - // TODO(dyma): we can get tenant from CollectionHandle - null, // tenant is not returned in the query - (Map) reference.properties(), - reference.vectors(), - reference.createdAt(), - reference.lastUpdatedAt(), - reference.queryMetadata(), - reference.references()); }) .toList(); diff --git a/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaultsTest.java b/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaultsTest.java index 
4b420b3fd..95d7cd126 100644 --- a/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaultsTest.java +++ b/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaultsTest.java @@ -24,7 +24,7 @@ public class CollectionHandleDefaultsTest { /** All defaults are {@code null} if none were set. */ @Test public void test_defaults() { - Assertions.assertThat(HANDLE_NONE.consistencyLevel()).as("default ConsistencyLevel").isNull(); + Assertions.assertThat(HANDLE_NONE.consistencyLevel()).as("default ConsistencyLevel").isEmpty(); Assertions.assertThat(HANDLE_NONE.tenant()).as("default tenant").isNull(); } @@ -35,8 +35,8 @@ public void test_defaults() { @Test public void test_withConsistencyLevel() { var handle = HANDLE_NONE.withConsistencyLevel(ConsistencyLevel.QUORUM); - Assertions.assertThat(handle.consistencyLevel()).isEqualTo(ConsistencyLevel.QUORUM); - Assertions.assertThat(HANDLE_NONE.consistencyLevel()).isNull(); + Assertions.assertThat(handle.consistencyLevel()).get().isEqualTo(ConsistencyLevel.QUORUM); + Assertions.assertThat(HANDLE_NONE.consistencyLevel()).isEmpty(); } /** @@ -46,8 +46,8 @@ public void test_withConsistencyLevel() { @Test public void test_withConsistencyLevel_async() { var handle = HANDLE_NONE_ASYNC.withConsistencyLevel(ConsistencyLevel.QUORUM); - Assertions.assertThat(handle.consistencyLevel()).isEqualTo(ConsistencyLevel.QUORUM); - Assertions.assertThat(HANDLE_NONE_ASYNC.consistencyLevel()).isNull(); + Assertions.assertThat(handle.consistencyLevel()).get().isEqualTo(ConsistencyLevel.QUORUM); + Assertions.assertThat(HANDLE_NONE_ASYNC.consistencyLevel()).isEmpty(); } /** @@ -58,7 +58,7 @@ public void test_withConsistencyLevel_async() { public void test_withTenant() { var handle = HANDLE_NONE.withTenant("john_doe"); Assertions.assertThat(handle.tenant()).isEqualTo("john_doe"); - Assertions.assertThat(HANDLE_NONE.consistencyLevel()).isNull(); + Assertions.assertThat(HANDLE_NONE.consistencyLevel()).isEmpty(); } /** @@ 
-69,6 +69,6 @@ public void test_withTenant() { public void test_withTenant_async() { var handle = HANDLE_NONE_ASYNC.withTenant("john_doe"); Assertions.assertThat(handle.tenant()).isEqualTo("john_doe"); - Assertions.assertThat(HANDLE_NONE_ASYNC.consistencyLevel()).isNull(); + Assertions.assertThat(HANDLE_NONE_ASYNC.consistencyLevel()).isEmpty(); } } diff --git a/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleTest.java b/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleTest.java index 9cf1e99d9..9ff578cfe 100644 --- a/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleTest.java +++ b/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleTest.java @@ -131,10 +131,10 @@ public void test_collectionHandleDefaults_rest(String __, rest.assertNext((method, requestUrl, body, query) -> { switch (clLoc) { case QUERY: - Assertions.assertThat(query).containsEntry("consistency_level", defaults.consistencyLevel()); + Assertions.assertThat(query).containsEntry("consistency_level", defaults.consistencyLevel().get()); break; case BODY: - assertJsonHasValue(body, "consistency_level", defaults.consistencyLevel()); + assertJsonHasValue(body, "consistency_level", defaults.consistencyLevel().get()); } switch (tenantLoc) { From df181cf89e6b2b285e58fac70fc61445caafb682 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Mon, 9 Feb 2026 21:35:27 +0100 Subject: [PATCH 02/22] wip(batch): add batch primitives (TaskHandle, Message, BatchContext --- .../api/collections/batch/BatchContext.java | 171 +++++++++++++ .../v1/api/collections/batch/Data.java | 92 +++++++ .../batch/DataTooBigException.java | 7 + .../v1/api/collections/batch/Event.java | 116 +++++++++ .../v1/api/collections/batch/Message.java | 239 ++++++++++++++++++ .../v1/api/collections/batch/State.java | 42 +++ .../api/collections/batch/StreamFactory.java | 7 + .../api/collections/batch/StreamMessage.java | 25 ++ .../v1/api/collections/batch/TaskHandle.java | 162 
++++++++++++ .../batch/TranslatingStreamFactory.java | 113 +++++++++ 10 files changed, 974 insertions(+) create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/batch/Data.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/batch/DataTooBigException.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/batch/Message.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/batch/StreamFactory.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/batch/StreamMessage.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java new file mode 100644 index 000000000..d56082ab6 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java @@ -0,0 +1,171 @@ +package io.weaviate.client6.v1.api.collections.batch; + +import static java.util.Objects.requireNonNull; + +import java.io.Closeable; +import java.io.IOException; +import java.util.Optional; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; + +import javax.annotation.concurrent.GuardedBy; + +import io.grpc.stub.StreamObserver; +import 
io.weaviate.client6.v1.api.collections.WeaviateObject; +import io.weaviate.client6.v1.api.collections.query.ConsistencyLevel; + +/** + * BatchContext stores the state of an active batch process + * and controls its lifecycle. + * + *

State

+ * + *

Lifecycle

+ * + *

Cancellation policy

+ * + */ +public final class BatchContext implements Closeable { + private final int DEFAULT_BATCH_SIZE = 1000; + private final int DEFAULT_QUEUE_SIZE = 100; + + private final Optional consistencyLevel; + + /** + * Queue publishes insert tasks from the main thread to the "sender". + * It has a maximum capacity of {@link #DEFAULT_QUEUE_SIZE}. + * + * Send {@link TaskHandle#POISON} to gracefully shutdown the "sender" + * thread. The same queue may be re-used with a different "sender", + * e.g. after {@link #reconnect}, but only when the new thread is known + * to have started. Otherwise the thread trying to put an item on + * the queue will block indefinitely. + */ + private final BlockingQueue queue; + + /** + * wip stores work-in-progress items. + * + * An item is added to the {@link #wip} map after the Sender successfully + * adds it to the {@link #message} and is removed once the server reports + * back the result (whether success of failure). + */ + private final ConcurrentMap wip = new ConcurrentHashMap<>(); + + /** + * Message buffers batch items before they're sent to the server. + * + *

+ * An item is added to the {@link #message} after the Sender pulls it + * from the queue and remains there until it's Ack'ed. + */ + private final Message message; + + /** + * State encapsulates state-dependent behavior of the {@link BatchContext}. + * Before reading {@link #state}, a thread MUST acquire {@link #lock}. + */ + @GuardedBy("lock") + private State state; + /** lock synchronizes access to {@link #state}. */ + private final Lock lock = new ReentrantLock(); + /** stateChanged notifies threads about a state transition. */ + private final Condition stateChanged = lock.newCondition(); + + BatchContext(Optional consistencyLevel, int maxSizeBytes) { + this.queue = new ArrayBlockingQueue<>(DEFAULT_QUEUE_SIZE); + this.message = new Message(DEFAULT_QUEUE_SIZE, maxSizeBytes); + this.consistencyLevel = consistencyLevel; + } + + void start() { + + } + + void reconnect() { + } + + /** Set the new state and notify awaiting threads. */ + void setState(State nextState) { + lock.lock(); + try { + state = nextState; + stateChanged.signal(); + } finally { + lock.unlock(); + } + } + + boolean canSend() { + lock.lock(); + try { + return state.canSend(); + } finally { + lock.unlock(); + } + } + + /** onEvent delegates event handling to {@link #state} */ + void onEvent(Event event) throws InterruptedException { + lock.lock(); + try { + state.onEvent(event); + } finally { + lock.unlock(); + } + } + + /** Add {@link WeaviateObject} to the batch. 
*/ + public TaskHandle add() { + return null; + } + + public TaskHandle retry(TaskHandle taskHandle) { + return null; + } + + private final class Sender implements Runnable { + private final StreamObserver stream; + + private Sender(StreamObserver stream) { + this.stream = requireNonNull(stream, "stream is null"); + } + + @Override + public void run() { + throw new UnsupportedOperationException("implement!"); + } + } + + private final class Recv implements StreamObserver { + + @Override + public void onCompleted() { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'onCompleted'"); + } + + @Override + public void onError(Throwable arg0) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'onError'"); + } + + @Override + public void onNext(Event arg0) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'onNext'"); + } + } + + @Override + public void close() throws IOException { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'close'"); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Data.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Data.java new file mode 100644 index 000000000..4ed5753be --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Data.java @@ -0,0 +1,92 @@ +package io.weaviate.client6.v1.api.collections.batch; + +import static java.util.Objects.requireNonNull; + +import javax.annotation.concurrent.Immutable; + +import com.google.protobuf.CodedOutputStream; +import com.google.protobuf.GeneratedMessage; +import com.google.protobuf.GeneratedMessageV3; + +import io.weaviate.client6.v1.api.collections.WeaviateObject; +import io.weaviate.client6.v1.api.collections.data.ObjectReference; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch; + +@Immutable 
+@SuppressWarnings("deprecation") // protoc uses GeneratedMessageV3 +class Data implements StreamMessage { + + /** + * Raw input value, as passed by the user. + */ + private final Object raw; + + /** + * Task ID. Depending on the underlying object, this will either be + * {@link WeaviateObject#uuid} or {@link ObjectReference#beacon}. + * + * Since UUIDs and beacons cannot clash, ID does not encode any + * information about the underlying data type. + */ + private final String id; + + /** + * Serialized representation of the {@link #raw}. This valus is immutable + * for the entire lifecycle of the handle. + */ + private final GeneratedMessage.ExtendableMessage message; + + /** Estimated size of the {@link #message} when serialized. */ + private final int sizeBytes; + + enum Type { + OBJECT(WeaviateProtoBatch.BatchStreamRequest.Data.OBJECTS_FIELD_NUMBER), + REFERENCE(WeaviateProtoBatch.BatchStreamRequest.Data.REFERENCES_FIELD_NUMBER); + + private final int fieldNumber; + + private Type(int fieldNumber) { + this.fieldNumber = fieldNumber; + } + } + + private Data(Object raw, String id, GeneratedMessage.ExtendableMessage message, int sizeBytes) { + this.raw = requireNonNull(raw, "raw is null"); + this.id = requireNonNull(id, "id is null"); + this.message = requireNonNull(message, "message is null"); + + assert sizeBytes >= 0; + this.sizeBytes = sizeBytes; + } + + static Data ofField(Object raw, String id, GeneratedMessage.ExtendableMessage message, + Type type) { + requireNonNull(type, "type is null"); + int sizeBytes = CodedOutputStream.computeMessageSize(type.fieldNumber, message); + return new Data(raw, id, message, sizeBytes); + } + + String id() { + return id; + } + + /** Serialized data size in bytes. 
*/ + int sizeBytes() { + return sizeBytes; + } + + @Override + public void appendTo(WeaviateProtoBatch.BatchStreamRequest.Builder builder) { + WeaviateProtoBatch.BatchStreamRequest.Data.Builder data = builder.getDataBuilder(); + if (message instanceof WeaviateProtoBatch.BatchObject object) { + data.getObjectsBuilder().addValues(object); + } else if (message instanceof WeaviateProtoBatch.BatchReference ref) { + data.getReferencesBuilder().addValues(ref); + } + } + + @Override + public String toString() { + return "%s (%s)".formatted(raw.getClass().getSimpleName(), id); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/DataTooBigException.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/DataTooBigException.java new file mode 100644 index 000000000..98ca81b49 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/DataTooBigException.java @@ -0,0 +1,7 @@ +package io.weaviate.client6.v1.api.collections.batch; + +public class DataTooBigException extends Exception { + DataTooBigException(Data data, long maxSizeBytes) { + super("%s with size=%dB exceeds maximum message size %dB".formatted(data, data.sizeBytes(), maxSizeBytes)); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java new file mode 100644 index 000000000..855187bd6 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java @@ -0,0 +1,116 @@ +package io.weaviate.client6.v1.api.collections.batch; + +import static java.util.Objects.requireNonNull; + +import java.util.Collection; +import java.util.List; +import java.util.Map; + +import io.weaviate.client6.v1.api.collections.batch.Event.Acks; +import io.weaviate.client6.v1.api.collections.batch.Event.Backoff; +import io.weaviate.client6.v1.api.collections.batch.Event.Results; +import io.weaviate.client6.v1.api.collections.batch.Event.Started; +import 
io.weaviate.client6.v1.api.collections.batch.Event.TerminationEvent; + +sealed interface Event + permits Started, Acks, Results, Backoff, TerminationEvent { + + final static Event STARTED = new Started(); + final static Event OOM = TerminationEvent.OOM; + final static Event SHUTTING_DOWN = TerminationEvent.SHUTTING_DOWN; + final static Event SHUTDOWN = TerminationEvent.SHUTDOWN; + + /** */ + record Started() implements Event { + } + + /** + * The server has added items from the previous message to its internal + * work queue, client MAY send the next batch. + * + *

+ * The protocol guarantess that {@link Acks} will contain IDs for all + * items sent in the previous batch. + */ + record Acks(Collection acked) implements Event { + public Acks { + acked = List.copyOf(requireNonNull(acked, "acked is null")); + } + } + + /** + * Results for the insertion of a previous batches. + * + *

+ * We assume that the server may return partial results, or return + * results out of the order of inserting messages. + */ + record Results(Collection successful, Map errors) implements Event { + public Results { + successful = List.copyOf(requireNonNull(successful, "successful is null")); + errors = Map.copyOf(requireNonNull(errors, "errors is null")); + } + } + + /** + * Backoff communicates the optimal batch size (number of objects) + * with respect to the current load on the server. + * + *

+ * Backoff is an instruction, not a recommendation. + * On receiving this message, the client must ensure that + * all messages it produces, including the one being prepared, + * do not exceed the size limit indicated by {@link #maxSize} + * until the server sends another Backoff message. The limit + * MUST also be respected after a {@link BatchContext#reconnect}. + * + *

+ * The client MAY use the latest {@link #maxSize} as the default + * message limit in a new {@link BatchContext}, but is not required to. + */ + record Backoff(int maxSize) implements Event { + } + + enum TerminationEvent implements Event { + /** + * Out-Of-Memory. + * + *

+ * Items sent in the previous request cannot be accepted, + * as inserting them may exhaust server's available disk space. + * On receiving this message, the client MUST stop producing + * messages immediately and await {@link #SHUTTING_DOWN} event. + * + *

+ * {@link #OOM} is the sibling of {@link Acks} with the opposite effect. + * The protocol guarantees that the server will respond with either of + * the two, but never both. + */ + OOM, + + /** + * Server shutdown in progress. + * + *

+ * The server began the process of graceful shutdown, due to a + * scale-up event (if it previously reported {@link #OOM}) or + * some other external event. + * On receiving this message, the client MUST stop producing + * messages immediately and close its side of the stream. + */ + SHUTTING_DOWN, + + /** + * Server is shutdown. + + *

+ * The server has finished the shutdown process and will not + * receive any messages. On receiving this message, the client + * MUST continue reading messages in the stream until the server + * closes it on its end, then re-connect to another instance + * by re-opening the stream and continue processing the batch. + */ + SHUTDOWN; + } + +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Message.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Message.java new file mode 100644 index 000000000..832444a46 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Message.java @@ -0,0 +1,239 @@ +package io.weaviate.client6.v1.api.collections.batch; + +import static java.util.Objects.requireNonNull; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.OptionalInt; +import java.util.Set; + +import javax.annotation.concurrent.GuardedBy; + +import com.google.protobuf.CodedOutputStream; +import com.nimbusds.jose.shaded.jcip.ThreadSafe; + +import io.weaviate.client6.v1.api.collections.batch.Event.Backoff; +import io.weaviate.client6.v1.internal.grpc.GrpcChannelOptions; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch; + +/** + * Message, along with its items, can be in either of 2 states: + *

    + *
  • Prepared message accepts new items and can be resized. + *
  • In-flight message is sealed: it rejects new items and avoids otherwise + * modifying the {@link #buffer} until it's ack'ed. + *
+ * + *

Class invariants

+ * Buffer size cannot exceed {@link #maxSize}. + * + *

Synchronization policy

+ */ +@ThreadSafe +final class Message { + private static int DATA_TAG_SIZE = CodedOutputStream + .computeTagSize(WeaviateProtoBatch.BatchStreamRequest.DATA_FIELD_NUMBER); + + /** Backlog MUST be confined to the "receiver" thread. */ + private final List backlog = new ArrayList<>(); + + /** + * Items stored in this message. + */ + @GuardedBy("this") + private final LinkedHashMap buffer; + + /** + * Maximum number of items that can be added to the request. + * Must be greater that zero. + * + *

+ * This is determined by the server's {@link Backoff} instruction. + */ + @GuardedBy("this") + private int maxSize; + + /** + * Maximum size of the serialized message in bytes. + * Must be greater than zero. + * + *

+ * This is determined by the {@link GrpcChannelOptions#maxMessageSize}. + */ + @GuardedBy("this") + private final long maxSizeBytes; + + /** Total size of all values in the buffer. */ + @GuardedBy("this") + private long sizeBytes; + + @GuardedBy("this") + private boolean inFlight = false; + + @GuardedBy("this") + private OptionalInt pendingMaxSize = OptionalInt.empty(); + + Message(int maxSize, int maxSizeBytes) { + assert maxSize > 0 : "non-positive maxSize"; + + // A protobuf field has layout {@code [tag][lenght(payload)][payload]}, + // so to estimate the message size correctly we must account for "tag" + // and "length", not just the raw payload. + if (maxSizeBytes <= DATA_TAG_SIZE) { + throw new IllegalArgumentException("Maximum message size must be at least %dB".formatted(DATA_TAG_SIZE)); + } + this.maxSizeBytes = maxSizeBytes - DATA_TAG_SIZE; + this.maxSize = maxSize; + this.buffer = new LinkedHashMap<>(maxSize); // LinkedHashMap preserves insertion order. + + checkInvariants(); + } + + /** + * Returns true if message has reached its capacity, either in terms + * of the item count or the message's estimated size in bytes. + */ + synchronized boolean isFull() { + return buffer.size() == maxSize || sizeBytes == maxSizeBytes; + } + + /** + * Returns true if the message's internal buffer is empty. + * If it's primary buffer is empty, its backlog is guaranteed + * to be empty as well. + */ + synchronized boolean isEmpty() { + return buffer.isEmpty(); // sizeBytes == 0 is guaranteed by class invariant. + } + + /** + * Prepare a request to be sent. After calling this method, this message becomes + * "in-flight": an attempt to {@link #add} more items to it will be rejected + * with an exception. + */ + synchronized StreamMessage prepare() { + inFlight = true; + return builder -> { + buffer.forEach((__, data) -> { + data.appendTo(builder); + }); + }; + } + + synchronized void setMaxSize(int maxSizeNew) { + try { + // In-flight message cannot be modified. 
+ // Store the requested maxSize for later; + // it will be applied on the next ack. + if (inFlight) { + pendingMaxSize = OptionalInt.of(maxSizeNew); + return; + } + + maxSize = maxSizeNew; + + // Buffer still fits under the new limit. + if (buffer.size() <= maxSize) { + return; + } + + // Buffer exceeds the new limit. + // Move extra items to the backlog in LIFO order. + Iterator> extra = buffer.reversed() + .entrySet().stream() + .limit(buffer.size() - maxSize) + .iterator(); + + while (extra.hasNext()) { + Map.Entry next = extra.next(); + backlog.add(next.getValue()); + extra.remove(); + } + // Reverse the backlog to restore the FIFO order. + Collections.reverse(backlog); + } finally { + checkInvariants(); + } + } + + /** + * Add a data item to the message. + * + * + * @throws DataTooBigException If the data exceeds the maximum + * possible message size. + * @throws IllegalStateException If called on an "in-flight" message. + * @see #prepare + * @see #inFlight + * + * @return Boolean indicating if the item has been accepted. + */ + synchronized boolean add(Data data) throws IllegalStateException, DataTooBigException { + requireNonNull(data, "data is null"); + + try { + if (inFlight) { + throw new IllegalStateException("Message is in-flight"); + } + if (data.sizeBytes() > maxSizeBytes - sizeBytes) { + if (isEmpty()) { + throw new DataTooBigException(data, maxSizeBytes); + } + return false; + } + addSafe(data); + return true; + } finally { + checkInvariants(); + } + } + + private synchronized void addSafe(Data data) { + buffer.put(data.id(), data); + sizeBytes += data.sizeBytes(); + } + + synchronized Iterable ack(Iterable acked) { + requireNonNull(acked, "acked are null"); + + try { + acked.forEach(id -> buffer.remove(id)); + Set remaining = Set.copyOf(buffer.keySet()); + + // Reset the in-flight status. + inFlight = false; + + // Populate message from the backlog. 
+ // We don't need to check the return value of .add(), + // as all items in the backlog are guaranteed to not + // exceed maxSizeBytes. + backlog.stream() + .takeWhile(__ -> !isFull()) + .forEach(this::addSafe); + + return remaining; + } finally { + checkInvariants(); + } + } + + /** Asserts the invariants of this class. */ + private synchronized void checkInvariants() { + assert maxSize > 0 : "non-positive maxSize"; + assert maxSizeBytes > 0 : "non-positive maxSizeBytes"; + assert sizeBytes >= 0 : "negative sizeBytes"; + assert buffer.size() <= maxSize : "buffer exceeds maxSize"; + assert sizeBytes <= maxSizeBytes : "message exceeds maxSizeBytes"; + if (buffer.size() < maxSize) { + assert backlog.isEmpty() : "backlog not empty when buffer not full"; + } + if (buffer.isEmpty()) { + assert sizeBytes == 0 : "sizeBytes must be 0 when buffer is empty"; + } + assert pendingMaxSize != null : "pending max size is null"; + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java new file mode 100644 index 000000000..9f016e728 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java @@ -0,0 +1,42 @@ +package io.weaviate.client6.v1.api.collections.batch; + +interface State { + /** + * canSend returns a boolean indicating if sending + * an "insert" message is allowed in this state. + */ + boolean canSend(); + + /** + * onEvent handles incoming events; these can be generated by the server + * or by a different part of the program -- the {@link State} MUST NOT + * make any assumptions about the event's origin. + * + *

+ * How the event is handled is up to the concrete implementation. + * It may modify {@link BatchContext} internal state, via one of it's + * package-private methods, including transitioning the context to a + * different state via {@link BatchContext#setState(State)}, or start + * a separate process, e.g. the OOM timer. + */ + void onEvent(Event event) throws InterruptedException; + + abstract class BaseState implements State { + @Override + public void onEvent(Event event) { + if (event instanceof Event.Acks acks) { + + } else if (event instanceof Event.Results results) { + + } else if (event instanceof Event.Backoff backoff) { + + } else if (event == Event.OOM) { + + } else if (event == Event.SHUTTING_DOWN) { + + } else { + throw new IllegalStateException("cannot handle " + event.getClass()); + } + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/StreamFactory.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/StreamFactory.java new file mode 100644 index 000000000..712b3e53c --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/StreamFactory.java @@ -0,0 +1,7 @@ +package io.weaviate.client6.v1.api.collections.batch; + +import io.grpc.stub.StreamObserver; + +interface StreamFactory { + StreamObserver createStream(StreamObserver out); +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/StreamMessage.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/StreamMessage.java new file mode 100644 index 000000000..9b999d6ba --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/StreamMessage.java @@ -0,0 +1,25 @@ +package io.weaviate.client6.v1.api.collections.batch; + +import java.util.Optional; + +import io.weaviate.client6.v1.api.collections.query.ConsistencyLevel; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch; + +@FunctionalInterface +interface StreamMessage { + void 
appendTo(WeaviateProtoBatch.BatchStreamRequest.Builder builder); + + static StreamMessage start(Optional consistencyLevel) { + final WeaviateProtoBatch.BatchStreamRequest.Start.Builder start = WeaviateProtoBatch.BatchStreamRequest.Start + .newBuilder(); + consistencyLevel.ifPresent(value -> value.appendTo(start)); + return builder -> builder.setStart(start); + } + + static final StreamMessage STOP = builder -> builder + .setStop(WeaviateProtoBatch.BatchStreamRequest.Stop.getDefaultInstance()); + + static StreamMessage stop() { + return STOP; + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java new file mode 100644 index 000000000..041e1a02a --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java @@ -0,0 +1,162 @@ +package io.weaviate.client6.v1.api.collections.batch; + +import static java.util.Objects.requireNonNull; + +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +import javax.annotation.concurrent.ThreadSafe; + +import com.google.protobuf.GeneratedMessage; +import com.google.protobuf.GeneratedMessageV3; + +import io.weaviate.client6.v1.api.collections.WeaviateObject; +import io.weaviate.client6.v1.api.collections.data.ObjectReference; + +@ThreadSafe +@SuppressWarnings("deprecation") // protoc uses GeneratedMessageV3 +public final class TaskHandle { + static final TaskHandle POISON = new TaskHandle(); + + /** + * Input value as passed by the user. + * + *

+ * Changes in the {@link #raw}'s underlying value will not be reflected + * in the {@link TaskHandle} (e.g. the serialized version is not updated), + * so users SHOULD treat items passed to and retrieved from {@link TaskHandle} + * as effectively ummodifiable. + */ + private final Data data; + + /** Flag indicatig the task has been ack'ed. */ + private final CompletableFuture acked = new CompletableFuture<>(); + + public final record Result(Optional error) { + public Result { + requireNonNull(error, "error is null"); + } + } + + /** + * Task result completes when the client receives {@link Event.Results} + * containing this handle's {@link #id}. + */ + private final CompletableFuture result = new CompletableFuture<>(); + + /** The number of times this task has been retried. */ + private final int retries; + + private TaskHandle(Data data, int retries) { + this.data = requireNonNull(data, "data is null"); + + assert retries >= 0 : "negative retries"; + this.retries = retries; + } + + /** Constructor for {@link WeaviateObject}. */ + TaskHandle(WeaviateObject object, GeneratedMessage.ExtendableMessage data) { + this(Data.ofField(object, object.uuid(), data, Data.Type.OBJECT), 0); + } + + /** Constructor for {@link ObjectReference}. */ + TaskHandle(ObjectReference reference, GeneratedMessage.ExtendableMessage data) { + this(Data.ofField(reference, reference.beacon(), data, Data.Type.REFERENCE), 0); + } + + /** + * Poison pill constructor. + * + *

+ * A handle created with this constructor should not be + * used for anything other that direct comparison using {@code ==} operator; + * calling any method on a poison pill is likely to result in a + * {@link NullPointerException} being thrown. + */ + private TaskHandle() { + this.data = null; + this.retries = 0; + } + + /** + * Creates a new task containing the same data as this task and {@link retries} + * counter incremented by 1. The {@link acked} and {@link result} futures + * are not copied to the returned task. + * + * @return Task handle. + */ + TaskHandle retry() { + return new TaskHandle(data, retries + 1); + } + + String id() { + return data.id(); + } + + /** Set the {@link #acked} flag. */ + void setAcked() { + acked.complete(null); + } + + /** + * Mark the task successful. This status cannot be changed, so calling + * {@link #setError} afterwards will have no effect. + */ + void setSuccess() { + setResult(new Result(Optional.empty())); + } + + /** + * Mark the task failed. This status cannot be changed, so calling + * {@link #setSuccess} afterwards will have no effect. + * + * @param error Error message. Null values are tolerated, but are only expected + * to occurr due to a server's mistake. + * Do not use {@code setError(null)} if the server reports success + * status for the task; prefer {@link #setSuccess} in that case. + */ + void setError(String error) { + setResult(new Result(Optional.ofNullable(error))); + } + + /** + * Set result for this task. + * + * @throws IllegalStateException if the task has not been ack'ed. + */ + private void setResult(Result result) { + if (!acked.isDone()) { + throw new IllegalStateException("Result can only be set for an ack'ed task"); + } + this.result.complete(result); + } + + /** + * Check if the task has been accepted. + * + * @return A future which completes when the server has accepted the task. + */ + public CompletableFuture isAcked() { + return acked; + } + + /** + * Retrieve the result for this task. 
+ * + * @return A future which completes when the server + * has reported the result for this task. + */ + public CompletableFuture result() { + return result; + } + + /** + * Number of times this task has been retried. Since {@link TaskHandle} is + * immutable, this value does not change, but retrying a task via + * {@link BatchContext#retry} is reflected in the returned handle's + * {@link #timesRetried}. + */ + public int timesRetried() { + return retries; + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java new file mode 100644 index 000000000..edbee0755 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java @@ -0,0 +1,113 @@ +package io.weaviate.client6.v1.api.collections.batch; + +import static java.util.Objects.requireNonNull; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import io.grpc.stub.StreamObserver; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch.BatchStreamReply; + +class TranslatingStreamFactory implements StreamFactory { + private final StreamFactory protoFactory; + + TranslatingStreamFactory( + StreamFactory protoFactory) { + this.protoFactory = requireNonNull(protoFactory, "protoFactory is null"); + } + + @Override + public StreamObserver createStream(StreamObserver eventObserver) { + return new MessageProducer(protoFactory.createStream(new EventHandler(eventObserver))); + } + + private abstract class DelegatingStreamObserver implements StreamObserver { + protected final StreamObserver delegate; + + protected DelegatingStreamObserver(StreamObserver delegate) { + this.delegate = delegate; + } + + @Override + public void onCompleted() { + delegate.onCompleted(); + } + + 
@Override + public void onError(Throwable t) { + delegate.onError(t); + } + } + + private final class MessageProducer + extends DelegatingStreamObserver { + private MessageProducer(StreamObserver delegate) { + super(delegate); + } + + @Override + public void onNext(StreamMessage message) { + WeaviateProtoBatch.BatchStreamRequest.Builder builder = WeaviateProtoBatch.BatchStreamRequest.newBuilder(); + message.appendTo(builder); + delegate.onNext(builder.build()); + } + } + + private final class EventHandler extends DelegatingStreamObserver { + private EventHandler(StreamObserver delegate) { + super(delegate); + } + + @Override + public void onNext(BatchStreamReply reply) { + Event event = null; + switch (reply.getMessageCase()) { + case STARTED: + event = Event.STARTED; + case SHUTTING_DOWN: + event = Event.SHUTTING_DOWN; + case SHUTDOWN: + event = Event.SHUTDOWN; + case OUT_OF_MEMORY: + event = Event.OOM; + case BACKOFF: + event = new Event.Backoff(reply.getBackoff().getBatchSize()); + case ACKS: + Stream uuids = reply.getAcks().getUuidsList().stream(); + Stream beacons = reply.getAcks().getBeaconsList().stream(); + event = new Event.Acks(Stream.concat(uuids, beacons).toList()); + case RESULTS: + List successful = reply.getResults().getSuccessesList().stream() + .map(detail -> { + if (detail.hasUuid()) { + return detail.getUuid(); + } else if (detail.hasBeacon()) { + return detail.getBeacon(); + } + throw new IllegalArgumentException("Result has neither UUID nor a beacon"); + }) + .toList(); + + Map errors = reply.getResults().getErrorsList().stream() + .map(detail -> { + String error = requireNonNull(detail.getError(), "error is null"); + if (detail.hasUuid()) { + return Map.entry(detail.getUuid(), error); + } else if (detail.hasBeacon()) { + return Map.entry(detail.getBeacon(), error); + } + throw new IllegalArgumentException("Result has neither UUID nor a beacon"); + }) + .collect(Collectors.toUnmodifiableMap(Map.Entry::getKey, Map.Entry::getValue)); + event = 
new Event.Results(successful, errors); + case MESSAGE_NOT_SET: + throw new IllegalArgumentException("Message not set"); + } + + delegate.onNext(event); + } + } +} From 2db0147ee2120ac01bc8bb0a9c6f1f7bf0156406 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Tue, 10 Feb 2026 11:10:57 +0100 Subject: [PATCH 03/22] feat(batch): start a new batch from CollectionHandle Extended GrpcTransport interface to implements StreamFactory. Added documentation to batch's primitives. Renamed Message -> Batch, StreamMessage -> Message, MessageProducer -> Messeger, and EventHandler -> Eventer. Extracted message size calculation into MessageSizeUtil. --- .../io/weaviate/client6/v1/api/Config.java | 4 +- .../v1/api/collections/batch/Batch.java | 311 ++++++++++++++++++ .../api/collections/batch/BatchContext.java | 31 +- .../v1/api/collections/batch/Data.java | 16 +- .../batch/DataTooBigException.java | 4 + .../v1/api/collections/batch/Message.java | 241 +------------- .../collections/batch/MessageSizeUtil.java | 44 +++ .../v1/api/collections/batch/State.java | 4 + .../api/collections/batch/StreamFactory.java | 8 +- .../api/collections/batch/StreamMessage.java | 25 -- .../v1/api/collections/batch/TaskHandle.java | 4 +- .../batch/TranslatingStreamFactory.java | 44 ++- .../batch/WeaviateBatchClient.java | 33 ++ .../collections/data/WeaviateDataClient.java | 2 +- .../internal/grpc/DefaultGrpcTransport.java | 76 +++-- .../v1/internal/grpc/GrpcChannelOptions.java | 10 +- .../v1/internal/grpc/GrpcTransport.java | 27 +- .../testutil/transport/MockGrpcTransport.java | 16 + 18 files changed, 585 insertions(+), 315 deletions(-) create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/batch/MessageSizeUtil.java delete mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/batch/StreamMessage.java create mode 100644 
src/main/java/io/weaviate/client6/v1/api/collections/batch/WeaviateBatchClient.java diff --git a/src/main/java/io/weaviate/client6/v1/api/Config.java b/src/main/java/io/weaviate/client6/v1/api/Config.java index b52b37066..116454893 100644 --- a/src/main/java/io/weaviate/client6/v1/api/Config.java +++ b/src/main/java/io/weaviate/client6/v1/api/Config.java @@ -193,8 +193,8 @@ private static boolean isWeaviateDomain(String host) { } private static final String VERSION = "weaviate-client-java/" - + ((!BuildInfo.TAGS.isBlank() && BuildInfo.TAGS != "unknown") ? BuildInfo.TAGS - : (BuildInfo.BRANCH + "-" + BuildInfo.COMMIT_ID_ABBREV)); + + ((!BuildInfo.TAGS.isBlank() && BuildInfo.TAGS != "unknown") ? BuildInfo.TAGS + : (BuildInfo.BRANCH + "-" + BuildInfo.COMMIT_ID_ABBREV)); @Override public Config build() { diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java new file mode 100644 index 000000000..c4cfc479f --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java @@ -0,0 +1,311 @@ +package io.weaviate.client6.v1.api.collections.batch; + +import static java.util.Objects.requireNonNull; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.OptionalInt; +import java.util.Set; + +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; + +import io.weaviate.client6.v1.api.collections.batch.Event.Backoff; +import io.weaviate.client6.v1.internal.grpc.GrpcChannelOptions; + +// assert maxSize > 0 : "non-positive maxSize"; +// assert maxSizeBytes > 0 : "non-positive maxSizeBytes"; +// assert sizeBytes >= 0 : "negative sizeBytes"; +// assert buffer.size() <= maxSize : "buffer exceeds maxSize"; +// assert sizeBytes <= maxSizeBytes : "message exceeds 
maxSizeBytes"; +// if (buffer.size() < maxSize) { +// assert backlog.isEmpty() : "backlog not empty when buffer not full"; +// } +// if (buffer.isEmpty()) { +// assert sizeBytes == 0 : "sizeBytes must be 0 when buffer is empty"; +// } +// assert pendingMaxSize != null : "pending max size is null"; + +/** + * Batch can be in either of 2 states: + *

    + *
  • Open batch accepts new items and can be resized. + *
  • In-flight batch is sealed: it rejects new items and + * avoids otherwise modifying the {@link #buffer} until it's cleared. + *
+ * + *

Class invariants

+ * + * {@link #maxSize} and {@link #maxSizeBytes} MUST be positive. + * A batch with {@code cap=0} is not useful.
+ * {@link #buffer} size and {@link #sizeBytes} MUST be non-negative.
+ * {@link #buffer} size MUST NOT exceed {@link #maxSize}.
+ * {@link #sizeBytes} MUST NOT exceed {@link #maxSizeBytes}.
+ * {@link #sizeBytes} MUST be 0 if the buffer is empty.
+ * {@link #backlog} MAY only contain items when {@link #buffer} is full.
+ * {@link #pendingMaxSize} is empty for an open batch. + + *

Synchronization policy

+ * + * @see #inFlight + * @see #isFull + * @see #clear + * @see #checkInvariants + */ +@ThreadSafe +final class Batch { + /** Backlog MUST be confined to the "receiver" thread. */ + private final List backlog = new ArrayList<>(); + + /** + * Items stored in this batch. + */ + @GuardedBy("this") + private final LinkedHashMap buffer; + + /** + * Maximum number of items that can be added to the request. + * Must be greater that zero. + * + *

+ * This is determined by the server's {@link Backoff} instruction. + */ + @GuardedBy("this") + private int maxSize; + + /** + * Maximum size of the serialized message in bytes. + * Must be greater than zero. + * + *

+ * This is determined by the {@link GrpcChannelOptions#maxMessageSize}. + */ + @GuardedBy("this") + private final long maxSizeBytes; + + /** Total serialized size of the items in the {@link #buffer}. */ + @GuardedBy("this") + private long sizeBytes; + + /** An in-flight batch is unmodifiable. */ + @GuardedBy("this") + private boolean inFlight = false; + + /** + * Pending update to the {@link #maxSize}. + * + * The value is non-empty when {@link #setMaxSize} is called + * while the batch is {@link #inFlight}. + */ + @GuardedBy("this") + private OptionalInt pendingMaxSize = OptionalInt.empty(); + + Batch(int maxSize, int maxSizeBytes) { + assert maxSize > 0 : "non-positive maxSize"; + + this.maxSizeBytes = MessageSizeUtil.maxSizeBytes(maxSizeBytes); + this.maxSize = maxSize; + this.buffer = new LinkedHashMap<>(maxSize); // LinkedHashMap preserves insertion order. + + checkInvariants(); + } + + /** + * Returns true if batch has reached its capacity, either in terms + * of the item count or the batch's estimated size in bytes. + */ + synchronized boolean isFull() { + return buffer.size() == maxSize || sizeBytes == maxSizeBytes; + } + + /** + * Returns true if the batch's internal buffer is empty. + * If it's primary buffer is empty, its backlog is guaranteed + * to be empty as well. + */ + synchronized boolean isEmpty() { + return buffer.isEmpty(); // sizeBytes == 0 is guaranteed by class invariant. + } + + /** + * Prepare a request to be sent. After calling this method, this batch becomes + * "in-flight": an attempt to {@link #add} more items to it will be rejected + * with an exception. + */ + synchronized Message prepare() { + checkInvariants(); + + inFlight = true; + return builder -> { + buffer.forEach((__, data) -> { + data.appendTo(builder); + }); + }; + } + + /** + * Set the new {@link #maxSize} for this buffer. + * + *

+ * How the size is applied depends on the buffer's current state: + *

    + *
  • When the batch is in-flight, the new limit is stored in + * {@link #pendingMaxSize} and will be applied once the batch is cleared. + *
  • While the batch is still open, the new limit is applied immediately and + * the {@link #pendingMaxSize} is set back to {@link OptionalInt#empty}. If + * the current buffer size exceeds the new limit, the overflow items are moved + * to the {@link #backlog}. + *
+ * + * @param maxSizeNew New batch size limit. + * + * @see #clear + */ + synchronized void setMaxSize(int maxSizeNew) { + checkInvariants(); + + try { + // In-flight batch cannot be modified. + // Store the requested maxSize for later; + // it will be applied on the next ack. + if (inFlight) { + pendingMaxSize = OptionalInt.of(maxSizeNew); + return; + } + + maxSize = maxSizeNew; + pendingMaxSize = OptionalInt.empty(); + + // Buffer still fits under the new limit. + if (buffer.size() <= maxSize) { + return; + } + + // Buffer exceeds the new limit. + // Move extra items to the backlog in LIFO order. + Iterator> extra = buffer.reversed() + .entrySet().stream() + .limit(buffer.size() - maxSize) + .iterator(); + + while (extra.hasNext()) { + Map.Entry next = extra.next(); + backlog.add(next.getValue()); + extra.remove(); + } + // Reverse the backlog to restore the FIFO order. + Collections.reverse(backlog); + } finally { + checkInvariants(); + } + } + + /** + * Add a data item to the batch. + * + * + * @throws DataTooBigException If the data exceeds the maximum + * possible batch size. + * @throws IllegalStateException If called on an "in-flight" batch. + * @see #prepare + * @see #inFlight + * + * @return Boolean indicating if the item has been accepted. + */ + synchronized boolean add(Data data) throws IllegalStateException, DataTooBigException { + requireNonNull(data, "data is null"); + checkInvariants(); + + try { + if (inFlight) { + throw new IllegalStateException("Batch is in-flight"); + } + if (data.sizeBytes() > maxSizeBytes - sizeBytes) { + if (isEmpty()) { + throw new DataTooBigException(data, maxSizeBytes); + } + return false; + } + addSafe(data); + return true; + } finally { + checkInvariants(); + } + } + + /** + * Add a data item to the batch. + * + * This method does not check {@link Data#sizeBytes}, so the caller + * must ensure that this item will not overflow the batch. 
+ */ + private synchronized void addSafe(Data data) { + buffer.put(data.id(), data); + sizeBytes += data.sizeBytes(); + } + + /** + * Clear this batch's internal buffer. + * + *

+ * Once the buffer is pruned, it is re-populated from the backlog + * until the former is full or the latter is exhaused. + * If {@link #pendingMaxSize} is not empty, it is applied + * before re-populating the buffer. + * + * @return IDs removed from the buffer. + */ + synchronized Collection clear() { + checkInvariants(); + + try { + inFlight = false; + + Set removed = Set.copyOf(buffer.keySet()); + buffer.clear(); + + if (pendingMaxSize.isPresent()) { + setMaxSize(pendingMaxSize.getAsInt()); + } + + // Populate internal buffer from the backlog. + // We don't need to check the return value of .add(), + // as all items in the backlog are guaranteed to not + // exceed maxSizeBytes. + backlog.stream() + .takeWhile(__ -> !isFull()) + .forEach(this::addSafe); + + return removed; + } finally { + checkInvariants(); + } + } + + /** Asserts the invariants of this class. */ + private synchronized void checkInvariants() { + assert maxSize > 0 : "non-positive maxSize"; + assert maxSizeBytes > 0 : "non-positive maxSizeBytes"; + assert sizeBytes >= 0 : "negative sizeBytes"; + assert buffer.size() <= maxSize : "buffer exceeds maxSize"; + assert sizeBytes <= maxSizeBytes : "message exceeds maxSizeBytes"; + if (!isFull()) { + assert backlog.isEmpty() : "backlog not empty when buffer not full"; + } + if (buffer.isEmpty()) { + assert sizeBytes == 0 : "sizeBytes must be 0 when buffer is empty"; + } + + requireNonNull(pendingMaxSize, "pendingMaxSize is null"); + if (!inFlight) { + assert pendingMaxSize.isEmpty() : "open batch has pending maxSize"; + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java index d56082ab6..167087fa2 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java @@ -29,11 +29,13 @@ * *

Cancellation policy

* + * @param the shape of properties for inserted objects. */ -public final class BatchContext implements Closeable { +public final class BatchContext implements Closeable { private final int DEFAULT_BATCH_SIZE = 1000; private final int DEFAULT_QUEUE_SIZE = 100; + private final StreamFactory streamFactory; private final Optional consistencyLevel; /** @@ -49,22 +51,22 @@ public final class BatchContext implements Closeable { private final BlockingQueue queue; /** - * wip stores work-in-progress items. + * Work-in-progress items. * * An item is added to the {@link #wip} map after the Sender successfully - * adds it to the {@link #message} and is removed once the server reports + * adds it to the {@link #batch} and is removed once the server reports * back the result (whether success of failure). */ private final ConcurrentMap wip = new ConcurrentHashMap<>(); /** - * Message buffers batch items before they're sent to the server. + * Current batch. * *

- * An item is added to the {@link #message} after the Sender pulls it + * An item is added to the {@link #batch} after the Sender pulls it * from the queue and remains there until it's Ack'ed. */ - private final Message message; + private final Batch batch; /** * State encapsulates state-dependent behavior of the {@link BatchContext}. @@ -77,10 +79,15 @@ public final class BatchContext implements Closeable { /** stateChanged notifies threads about a state transition. */ private final Condition stateChanged = lock.newCondition(); - BatchContext(Optional consistencyLevel, int maxSizeBytes) { + BatchContext( + StreamFactory streamFactory, + int maxSizeBytes, + Optional consistencyLevel) { + this.streamFactory = requireNonNull(streamFactory, "streamFactory is null"); + this.consistencyLevel = requireNonNull(consistencyLevel, "consistencyLevel is null"); + this.queue = new ArrayBlockingQueue<>(DEFAULT_QUEUE_SIZE); - this.message = new Message(DEFAULT_QUEUE_SIZE, maxSizeBytes); - this.consistencyLevel = consistencyLevel; + this.batch = new Batch(DEFAULT_BATCH_SIZE, maxSizeBytes); } void start() { @@ -92,6 +99,8 @@ void reconnect() { /** Set the new state and notify awaiting threads. 
*/ void setState(State nextState) { + requireNonNull(nextState, "nextState is null"); + lock.lock(); try { state = nextState; @@ -130,9 +139,9 @@ public TaskHandle retry(TaskHandle taskHandle) { } private final class Sender implements Runnable { - private final StreamObserver stream; + private final StreamObserver stream; - private Sender(StreamObserver stream) { + private Sender(StreamObserver stream) { this.stream = requireNonNull(stream, "stream is null"); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Data.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Data.java index 4ed5753be..dba9e79ad 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Data.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Data.java @@ -4,7 +4,6 @@ import javax.annotation.concurrent.Immutable; -import com.google.protobuf.CodedOutputStream; import com.google.protobuf.GeneratedMessage; import com.google.protobuf.GeneratedMessageV3; @@ -14,7 +13,7 @@ @Immutable @SuppressWarnings("deprecation") // protoc uses GeneratedMessageV3 -class Data implements StreamMessage { +class Data implements Message { /** * Raw input value, as passed by the user. 
@@ -48,6 +47,10 @@ enum Type { private Type(int fieldNumber) { this.fieldNumber = fieldNumber; } + + public int fieldNumber() { + return fieldNumber; + } } private Data(Object raw, String id, GeneratedMessage.ExtendableMessage message, int sizeBytes) { @@ -59,11 +62,9 @@ private Data(Object raw, String id, GeneratedMessage.ExtendableMessage message, + Data(Object raw, String id, GeneratedMessage.ExtendableMessage message, Type type) { - requireNonNull(type, "type is null"); - int sizeBytes = CodedOutputStream.computeMessageSize(type.fieldNumber, message); - return new Data(raw, id, message, sizeBytes); + this(raw, id, message, MessageSizeUtil.ofDataField(message, type)); } String id() { @@ -77,7 +78,8 @@ int sizeBytes() { @Override public void appendTo(WeaviateProtoBatch.BatchStreamRequest.Builder builder) { - WeaviateProtoBatch.BatchStreamRequest.Data.Builder data = builder.getDataBuilder(); + WeaviateProtoBatch.BatchStreamRequest.Data.Builder data = requireNonNull(builder, "builder is null") + .getDataBuilder(); if (message instanceof WeaviateProtoBatch.BatchObject object) { data.getObjectsBuilder().addValues(object); } else if (message instanceof WeaviateProtoBatch.BatchReference ref) { diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/DataTooBigException.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/DataTooBigException.java index 98ca81b49..e6adb4c12 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/DataTooBigException.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/DataTooBigException.java @@ -1,5 +1,9 @@ package io.weaviate.client6.v1.api.collections.batch; +/** + * DataTooBigException is thrown when a single object exceeds + * the maximum size of a gRPC message. 
+ */ public class DataTooBigException extends Exception { DataTooBigException(Data data, long maxSizeBytes) { super("%s with size=%dB exceeds maximum message size %dB".formatted(data, data.sizeBytes(), maxSizeBytes)); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Message.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Message.java index 832444a46..9d387d1b1 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Message.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Message.java @@ -2,238 +2,31 @@ import static java.util.Objects.requireNonNull; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Iterator; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.OptionalInt; -import java.util.Set; +import java.util.Optional; -import javax.annotation.concurrent.GuardedBy; - -import com.google.protobuf.CodedOutputStream; -import com.nimbusds.jose.shaded.jcip.ThreadSafe; - -import io.weaviate.client6.v1.api.collections.batch.Event.Backoff; -import io.weaviate.client6.v1.internal.grpc.GrpcChannelOptions; +import io.weaviate.client6.v1.api.collections.query.ConsistencyLevel; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch; -/** - * Message, along with its items, can be in either of 2 states: - *

    - *
  • Prepared message accepts new items and can be resized. - *
  • In-flight message if sealed: it rejects new items and avoids otherwise - * modifying the {@link #buffer} until it's ack'ed. - *
- * - *

Class invariants

- * Buffer size cannot exceed {@link #maxSize}. - * - *

Synchronization policy

- */ -@ThreadSafe -final class Message { - private static int DATA_TAG_SIZE = CodedOutputStream - .computeTagSize(WeaviateProtoBatch.BatchStreamRequest.DATA_FIELD_NUMBER); - - /** Backlog MUST be confined to the "receiver" thread. */ - private final List backlog = new ArrayList<>(); - - /** - * Items stored in this message. - */ - @GuardedBy("this") - private final LinkedHashMap buffer; - - /** - * Maximum number of items that can be added to the request. - * Must be greater that zero. - * - *

- * This is determined by the server's {@link Backoff} instruction. - */ - @GuardedBy("this") - private int maxSize; - - /** - * Maximum size of the serialized message in bytes. - * Must be greater that zero. - * - *

- * This is determined by the {@link GrpcChannelOptions#maxMessageSize}. - */ - @GuardedBy("this") - private final long maxSizeBytes; - - /** Total size of all values in the buffer. */ - @GuardedBy("this") - private long sizeBytes; - - @GuardedBy("this") - private boolean inFlight = false; +@FunctionalInterface +interface Message { + void appendTo(WeaviateProtoBatch.BatchStreamRequest.Builder builder); - @GuardedBy("this") - private OptionalInt pendingMaxSize = OptionalInt.empty(); + /** Create a Start message. */ + static Message start(Optional consistencyLevel) { + requireNonNull(consistencyLevel, "consistencyLevel is null"); - Message(int maxSize, int maxSizeBytes) { - assert maxSize > 0 : "non-positive maxSize"; - - // A protobuf field has layout {@code [tag][lenght(payload)][payload]}, - // so to estimate the message size correctly we must account for "tag" - // and "length", not just the raw payload. - if (maxSizeBytes <= DATA_TAG_SIZE) { - throw new IllegalArgumentException("Maximum message size must be at least %dB".formatted(DATA_TAG_SIZE)); - } - this.maxSizeBytes = maxSizeBytes - DATA_TAG_SIZE; - this.maxSize = maxSize; - this.buffer = new LinkedHashMap<>(maxSize); // LinkedHashMap preserves insertion order. - - checkInvariants(); - } - - /** - * Returns true if message has reached its capacity, either in terms - * of the item count or the message's estimated size in bytes. - */ - synchronized boolean isFull() { - return buffer.size() == maxSize || sizeBytes == maxSizeBytes; + final WeaviateProtoBatch.BatchStreamRequest.Start.Builder start = WeaviateProtoBatch.BatchStreamRequest.Start + .newBuilder(); + consistencyLevel.ifPresent(value -> value.appendTo(start)); + return builder -> builder.setStart(start); } - /** - * Returns true if the message's internal buffer is empty. - * If it's primary buffer is empty, its backlog is guaranteed - * to be empty as well. 
- */ - synchronized boolean isEmpty() { - return buffer.isEmpty(); // sizeBytes == 0 is guaranteed by class invariant. + /** Create a Stop message. */ + static Message stop() { + return STOP; } - /** - * Prepare a request to be sent. After calling this method, this message becomes - * "in-flight": an attempt to {@link #add} more items to it will be rejected - * with an exception. - */ - synchronized StreamMessage prepare() { - inFlight = true; - return builder -> { - buffer.forEach((__, data) -> { - data.appendTo(builder); - }); - }; - } - - synchronized void setMaxSize(int maxSizeNew) { - try { - // In-flight message cannot be modified. - // Store the requested maxSize for later; - // it will be applied on the next ack. - if (inFlight) { - pendingMaxSize = OptionalInt.of(maxSizeNew); - return; - } - - maxSize = maxSizeNew; - - // Buffer still fits under the new limit. - if (buffer.size() <= maxSize) { - return; - } + static final Message STOP = builder -> builder + .setStop(WeaviateProtoBatch.BatchStreamRequest.Stop.getDefaultInstance()); - // Buffer exceeds the new limit. - // Move extra items to the backlog in LIFO order. - Iterator> extra = buffer.reversed() - .entrySet().stream() - .limit(buffer.size() - maxSize) - .iterator(); - - while (extra.hasNext()) { - Map.Entry next = extra.next(); - backlog.add(next.getValue()); - extra.remove(); - } - // Reverse the backlog to restore the FIFO order. - Collections.reverse(backlog); - } finally { - checkInvariants(); - } - } - - /** - * Add a data item to the message. - * - * - * @throws DataTooBigException If the data exceeds the maximum - * possible message size. - * @throws IllegalStateException If called on an "in-flight" message. - * @see #prepare - * @see #inFlight - * - * @return Boolean indicating if the item has been accepted. 
- */ - synchronized boolean add(Data data) throws IllegalStateException, DataTooBigException { - requireNonNull(data, "data is null"); - - try { - if (inFlight) { - throw new IllegalStateException("Message is in-flight"); - } - if (data.sizeBytes() > maxSizeBytes - sizeBytes) { - if (isEmpty()) { - throw new DataTooBigException(data, maxSizeBytes); - } - return false; - } - addSafe(data); - return true; - } finally { - checkInvariants(); - } - } - - private synchronized void addSafe(Data data) { - buffer.put(data.id(), data); - sizeBytes += data.sizeBytes(); - } - - synchronized Iterable ack(Iterable acked) { - requireNonNull(acked, "acked are null"); - - try { - acked.forEach(id -> buffer.remove(id)); - Set remaining = Set.copyOf(buffer.keySet()); - - // Reset the in-flight status. - inFlight = false; - - // Populate message from the backlog. - // We don't need to check the return value of .add(), - // as all items in the backlog are guaranteed to not - // exceed maxSizeBytes. - backlog.stream() - .takeWhile(__ -> !isFull()) - .forEach(this::addSafe); - - return remaining; - } finally { - checkInvariants(); - } - } - - /** Asserts the invariants of this class. 
*/ - private synchronized void checkInvariants() { - assert maxSize > 0 : "non-positive maxSize"; - assert maxSizeBytes > 0 : "non-positive maxSizeBytes"; - assert sizeBytes >= 0 : "negative sizeBytes"; - assert buffer.size() <= maxSize : "buffer exceeds maxSize"; - assert sizeBytes <= maxSizeBytes : "message exceeds maxSizeBytes"; - if (buffer.size() < maxSize) { - assert backlog.isEmpty() : "backlog not empty when buffer not full"; - } - if (buffer.isEmpty()) { - assert sizeBytes == 0 : "sizeBytes must be 0 when buffer is empty"; - } - assert pendingMaxSize != null : "pending max size is null"; - } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/MessageSizeUtil.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/MessageSizeUtil.java new file mode 100644 index 000000000..f3c934232 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/MessageSizeUtil.java @@ -0,0 +1,44 @@ +package io.weaviate.client6.v1.api.collections.batch; + +import static java.util.Objects.requireNonNull; + +import com.google.protobuf.CodedOutputStream; +import com.google.protobuf.GeneratedMessage; +import com.google.protobuf.GeneratedMessageV3; + +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch; + +final class MessageSizeUtil { + private static int DATA_TAG_SIZE = CodedOutputStream + .computeTagSize(WeaviateProtoBatch.BatchStreamRequest.DATA_FIELD_NUMBER); + + private MessageSizeUtil() { + } + + /** + * Adjust batch byte-size limit to account for the + * {@link WeaviateProtoBatch.BatchStreamRequest.Data} container. + * + *

+ * A protobuf field has layout {@code [tag][lenght(payload)][payload]}, + * so to estimate the batch size correctly we must account for "tag" + * and "length", not just the raw payload. + */ + static long maxSizeBytes(long maxSizeBytes) { + if (maxSizeBytes <= DATA_TAG_SIZE) { + throw new IllegalArgumentException("Maximum batch size must be at least %dB".formatted(DATA_TAG_SIZE)); + } + return maxSizeBytes - DATA_TAG_SIZE; + } + + /** + * Calculate the size of a serialized + * {@link WeaviateProtoBatch.BatchStreamRequest.Data} field. + */ + @SuppressWarnings("deprecation") // protoc uses GeneratedMessageV3 + static int ofDataField(GeneratedMessage.ExtendableMessage message, Data.Type type) { + requireNonNull(type, "type is null"); + requireNonNull(message, "message is null"); + return CodedOutputStream.computeMessageSize(type.fieldNumber(), message); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java index 9f016e728..52a2adab6 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java @@ -1,5 +1,7 @@ package io.weaviate.client6.v1.api.collections.batch; +import static java.util.Objects.requireNonNull; + interface State { /** * canSend returns a boolean indicating if sending @@ -24,6 +26,8 @@ interface State { abstract class BaseState implements State { @Override public void onEvent(Event event) { + requireNonNull(event, "event is null"); + if (event instanceof Event.Acks acks) { } else if (event instanceof Event.Results results) { diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/StreamFactory.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/StreamFactory.java index 712b3e53c..c50eec6ee 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/StreamFactory.java +++ 
b/src/main/java/io/weaviate/client6/v1/api/collections/batch/StreamFactory.java @@ -2,6 +2,12 @@ import io.grpc.stub.StreamObserver; +/** + * @param the type of the object sent down the stream. + * @param the type of the object received from the stream. + */ +@FunctionalInterface interface StreamFactory { - StreamObserver createStream(StreamObserver out); + /** Create a new stream for the send-recv observer pair. */ + StreamObserver createStream(StreamObserver recv); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/StreamMessage.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/StreamMessage.java deleted file mode 100644 index 9b999d6ba..000000000 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/StreamMessage.java +++ /dev/null @@ -1,25 +0,0 @@ -package io.weaviate.client6.v1.api.collections.batch; - -import java.util.Optional; - -import io.weaviate.client6.v1.api.collections.query.ConsistencyLevel; -import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch; - -@FunctionalInterface -interface StreamMessage { - void appendTo(WeaviateProtoBatch.BatchStreamRequest.Builder builder); - - static StreamMessage start(Optional consistencyLevel) { - final WeaviateProtoBatch.BatchStreamRequest.Start.Builder start = WeaviateProtoBatch.BatchStreamRequest.Start - .newBuilder(); - consistencyLevel.ifPresent(value -> value.appendTo(start)); - return builder -> builder.setStart(start); - } - - static final StreamMessage STOP = builder -> builder - .setStop(WeaviateProtoBatch.BatchStreamRequest.Stop.getDefaultInstance()); - - static StreamMessage stop() { - return STOP; - } -} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java index 041e1a02a..e7c6799cc 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java +++ 
b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java @@ -56,12 +56,12 @@ private TaskHandle(Data data, int retries) { /** Constructor for {@link WeaviateObject}. */ TaskHandle(WeaviateObject object, GeneratedMessage.ExtendableMessage data) { - this(Data.ofField(object, object.uuid(), data, Data.Type.OBJECT), 0); + this(new Data(object, object.uuid(), data, Data.Type.OBJECT), 0); } /** Constructor for {@link ObjectReference}. */ TaskHandle(ObjectReference reference, GeneratedMessage.ExtendableMessage data) { - this(Data.ofField(reference, reference.beacon(), data, Data.Type.REFERENCE), 0); + this(new Data(reference, reference.beacon(), data, Data.Type.REFERENCE), 0); } /** diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java index edbee0755..530ed31e8 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java @@ -8,10 +8,20 @@ import java.util.stream.Stream; import io.grpc.stub.StreamObserver; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch.BatchStreamReply; -class TranslatingStreamFactory implements StreamFactory { +/** + * TranslatingStreamFactory is an adaptor for the + * {@link WeaviateGrpc.WeaviateStub#batchStream} factory. The returned stream + * translates client-side messages into protobuf requests and server-side + * replies into events. 
+ * + * @see Message + * @see Event + */ +class TranslatingStreamFactory implements StreamFactory { private final StreamFactory protoFactory; TranslatingStreamFactory( @@ -20,10 +30,17 @@ class TranslatingStreamFactory implements StreamFactory { } @Override - public StreamObserver createStream(StreamObserver eventObserver) { - return new MessageProducer(protoFactory.createStream(new EventHandler(eventObserver))); + public StreamObserver createStream(StreamObserver recv) { + return new Messeger(protoFactory.createStream(new Eventer(recv))); } + /** + * DelegatingStreamObserver delegates {@link #onCompleted} and {@link #onError} + * to another observer and translates the messages in {@link #onNext}. + * + * @param the type of the incoming message. + * @param the type of the message handed to the delegate. + */ private abstract class DelegatingStreamObserver implements StreamObserver { protected final StreamObserver delegate; @@ -42,22 +59,31 @@ public void onError(Throwable t) { } } - private final class MessageProducer - extends DelegatingStreamObserver { - private MessageProducer(StreamObserver delegate) { + /** + * Messeger translates client's messages into batch stream requests. + * + * @see Message + */ + private final class Messeger extends DelegatingStreamObserver { + private Messeger(StreamObserver delegate) { super(delegate); } @Override - public void onNext(StreamMessage message) { + public void onNext(Message message) { WeaviateProtoBatch.BatchStreamRequest.Builder builder = WeaviateProtoBatch.BatchStreamRequest.newBuilder(); message.appendTo(builder); delegate.onNext(builder.build()); } } - private final class EventHandler extends DelegatingStreamObserver { - private EventHandler(StreamObserver delegate) { + /** + * Eventer translates server replies into events. 
+ * + * @see Event + */ + private final class Eventer extends DelegatingStreamObserver { + private Eventer(StreamObserver delegate) { super(delegate); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/WeaviateBatchClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/WeaviateBatchClient.java new file mode 100644 index 000000000..d62856c24 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/WeaviateBatchClient.java @@ -0,0 +1,33 @@ +package io.weaviate.client6.v1.api.collections.batch; + +import static java.util.Objects.requireNonNull; + +import java.util.OptionalInt; + +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.internal.grpc.GrpcTransport; + +public class WeaviateBatchClient { + private final CollectionHandleDefaults defaults; + private final GrpcTransport grpcTransport; + + public WeaviateBatchClient(GrpcTransport grpcTransport, CollectionHandleDefaults defaults) { + this.defaults = requireNonNull(defaults, "defaults is null"); + this.grpcTransport = requireNonNull(grpcTransport, "grpcTransport is null"); + } + + /** Copy constructor with new defaults. 
*/ + public WeaviateBatchClient(WeaviateBatchClient c, CollectionHandleDefaults defaults) { + this.defaults = requireNonNull(defaults, "defaults is null"); + this.grpcTransport = c.grpcTransport; + } + + public BatchContext start() { + OptionalInt maxSizeBytes = grpcTransport.maxMessageSizeBytes(); + if (maxSizeBytes.isEmpty()) { + throw new IllegalStateException("Server must have grpcMaxMessageSize configured to use server-side batching"); + } + StreamFactory streamFactory = new TranslatingStreamFactory(grpcTransport::createStream); + return new BatchContext<>(streamFactory, maxSizeBytes.getAsInt(), defaults.consistencyLevel()); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClient.java index c7497ec64..19613dff2 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClient.java @@ -35,7 +35,7 @@ public WeaviateDataClient( this.defaults = defaults; } - /** Copy constructor that updates the {@link #query} to use new defaults. */ + /** Copy constructor with new defaults. 
*/ public WeaviateDataClient(WeaviateDataClient c, CollectionHandleDefaults defaults) { this.restTransport = c.restTransport; this.grpcTransport = c.grpcTransport; diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java index d12255d22..55f428bfb 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java +++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java @@ -1,6 +1,10 @@ package io.weaviate.client6.v1.internal.grpc; +import static java.util.Objects.requireNonNull; + +import java.util.OptionalInt; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; import javax.net.ssl.SSLException; @@ -16,57 +20,57 @@ import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext; import io.grpc.stub.AbstractStub; import io.grpc.stub.MetadataUtils; +import io.grpc.stub.StreamObserver; import io.weaviate.client6.v1.api.WeaviateApiException; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateBlockingStub; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateFutureStub; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch.BatchStreamReply; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch.BatchStreamRequest; public final class DefaultGrpcTransport implements GrpcTransport { + /** + * ListenableFuture callbacks are executed + * in the same thread they are called from. 
+ */ + private static final Executor FUTURE_CALLBACK_EXECUTOR = Runnable::run; + + private final GrpcChannelOptions transportOptions; private final ManagedChannel channel; private final WeaviateBlockingStub blockingStub; private final WeaviateFutureStub futureStub; - private final GrpcChannelOptions transportOptions; - private TokenCallCredentials callCredentials; public DefaultGrpcTransport(GrpcChannelOptions transportOptions) { - this.transportOptions = transportOptions; - this.channel = buildChannel(transportOptions); - - var blockingStub = WeaviateGrpc.newBlockingStub(channel) - .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers())); - - var futureStub = WeaviateGrpc.newFutureStub(channel) - .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers())); - - if (transportOptions.maxMessageSize() != null) { - var max = transportOptions.maxMessageSize(); - blockingStub = blockingStub.withMaxInboundMessageSize(max).withMaxOutboundMessageSize(max); - futureStub = futureStub.withMaxInboundMessageSize(max).withMaxOutboundMessageSize(max); - } + requireNonNull(transportOptions, "transportOptions is null"); + this.transportOptions = transportOptions; if (transportOptions.tokenProvider() != null) { this.callCredentials = new TokenCallCredentials(transportOptions.tokenProvider()); - blockingStub = blockingStub.withCallCredentials(callCredentials); - futureStub = futureStub.withCallCredentials(callCredentials); } - this.blockingStub = blockingStub; - this.futureStub = futureStub; + this.channel = buildChannel(transportOptions); + this.blockingStub = configure(WeaviateGrpc.newBlockingStub(channel)); + this.futureStub = configure(WeaviateGrpc.newFutureStub(channel)); } private > StubT applyTimeout(StubT stub, Rpc rpc) { if (transportOptions.timeout() == null) { return stub; } - var timeout = rpc.isInsert() + int timeout = rpc.isInsert() ? 
transportOptions.timeout().insertSeconds() : transportOptions.timeout().querySeconds(); return stub.withDeadlineAfter(timeout, TimeUnit.SECONDS); } + @Override + public OptionalInt maxMessageSizeBytes() { + return transportOptions.maxMessageSize(); + } + @Override public ResponseT performRequest(RequestT request, Rpc rpc) { @@ -96,7 +100,9 @@ public CompletableFuture perf * reusing the thread in which the original future is completed. */ private static final CompletableFuture toCompletableFuture(ListenableFuture listenable) { - var completable = new CompletableFuture(); + requireNonNull(listenable, "listenable is null"); + + CompletableFuture completable = new CompletableFuture<>(); Futures.addCallback(listenable, new FutureCallback() { @Override @@ -113,13 +119,14 @@ public void onFailure(Throwable t) { completable.completeExceptionally(t); } - }, Runnable::run); + }, FUTURE_CALLBACK_EXECUTOR); return completable; } private static ManagedChannel buildChannel(GrpcChannelOptions transportOptions) { - var channel = NettyChannelBuilder.forAddress(transportOptions.host(), transportOptions.port()); + requireNonNull(transportOptions, "transportOptions is null"); + NettyChannelBuilder channel = NettyChannelBuilder.forAddress(transportOptions.host(), transportOptions.port()); if (transportOptions.isSecure()) { channel.useTransportSecurity(); } else { @@ -140,10 +147,29 @@ private static ManagedChannel buildChannel(GrpcChannelOptions transportOptions) } channel.intercept(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers())); - return channel.build(); } + @Override + public StreamObserver createStream(StreamObserver recv) { + return configure(WeaviateGrpc.newStub(channel)).batchStream(recv); + } + + /** Apply common configuration to a stub. 
*/ + private > S configure(S stub) { + requireNonNull(stub, "stub is null"); + + stub = stub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers())); + if (transportOptions.maxMessageSize().isPresent()) { + int max = transportOptions.maxMessageSize().getAsInt(); + stub = stub.withMaxInboundMessageSize(max).withMaxOutboundMessageSize(max); + } + if (callCredentials != null) { + stub = stub.withCallCredentials(callCredentials); + } + return stub; + } + @Override public void close() throws Exception { channel.shutdown(); diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java index 5e4453d7f..96366cb5f 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java +++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java @@ -1,6 +1,7 @@ package io.weaviate.client6.v1.internal.grpc; import java.util.Map; +import java.util.OptionalInt; import javax.net.ssl.TrustManagerFactory; @@ -10,7 +11,7 @@ import io.weaviate.client6.v1.internal.TransportOptions; public class GrpcChannelOptions extends TransportOptions { - private final Integer maxMessageSize; + private final OptionalInt maxMessageSize; public GrpcChannelOptions(String scheme, String host, int port, Map headers, TokenProvider tokenProvider, TrustManagerFactory tmf, Timeout timeout) { @@ -18,17 +19,18 @@ public GrpcChannelOptions(String scheme, String host, int port, Map ResponseT performRequest(RequestT request, - Rpc rpc); + ResponseT performRequest(RequestT request, + Rpc rpc); + + CompletableFuture performRequestAsync(RequestT request, + Rpc rpc); + + /** + * Create stream for batch insertion. + * + * @apiNote Batch insertion is presently the only operation performed over a + * StreamStream connection, which is why we do not parametrize this + * method. 
+ */ + StreamObserver createStream( + StreamObserver recv); - CompletableFuture performRequestAsync(RequestT request, - Rpc rpc); + /** + * Maximum inbound/outbound message size supported by the underlying channel. + */ + OptionalInt maxMessageSizeBytes(); } diff --git a/src/test/java/io/weaviate/testutil/transport/MockGrpcTransport.java b/src/test/java/io/weaviate/testutil/transport/MockGrpcTransport.java index ebea2fea7..5778fbc6d 100644 --- a/src/test/java/io/weaviate/testutil/transport/MockGrpcTransport.java +++ b/src/test/java/io/weaviate/testutil/transport/MockGrpcTransport.java @@ -3,14 +3,18 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.OptionalInt; import java.util.concurrent.CompletableFuture; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.MessageOrBuilder; import com.google.protobuf.util.JsonFormat; +import io.grpc.stub.StreamObserver; import io.weaviate.client6.v1.internal.grpc.GrpcTransport; import io.weaviate.client6.v1.internal.grpc.Rpc; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch.BatchStreamReply; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch.BatchStreamRequest; public class MockGrpcTransport implements GrpcTransport { @@ -57,4 +61,16 @@ public CompletableFuture perf @Override public void close() throws IOException { } + + @Override + public StreamObserver createStream(StreamObserver recv) { + // TODO(dyma): implement for tests + throw new UnsupportedOperationException("Unimplemented method 'createStream'"); + } + + @Override + public OptionalInt maxMessageSizeBytes() { + // TODO(dyma): implement for tests + throw new UnsupportedOperationException("Unimplemented method 'maxMessageSizeBytes'"); + } } From 5af3cb52d5bd553cfd265720605f634167dc249f Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 11 Feb 2026 01:53:14 +0100 Subject: [PATCH 04/22] feat(batch): implement state transitions and event 
handling Added sketch implementation for the Send routine. Still raw and riddled with comments, but it's a good start. --- .../v1/api/collections/CollectionHandle.java | 4 +- .../collections/CollectionHandleAsync.java | 2 +- .../collections/CollectionHandleDefaults.java | 16 +- .../aggregate/AggregateRequest.java | 4 +- .../api/collections/batch/BatchContext.java | 384 ++++++++++++++++-- .../v1/api/collections/batch/State.java | 29 +- .../v1/api/collections/batch/TaskHandle.java | 12 +- .../batch/WeaviateBatchClient.java | 15 +- .../collections/config/GetShardsRequest.java | 4 +- .../api/collections/data/BatchReference.java | 13 +- .../collections/data/DeleteManyRequest.java | 4 +- .../collections/data/InsertManyRequest.java | 23 +- .../collections/data/InsertObjectRequest.java | 2 +- .../data/ReplaceObjectRequest.java | 2 +- .../collections/data/UpdateObjectRequest.java | 2 +- .../api/collections/query/QueryRequest.java | 5 +- .../api/collections/CollectionHandleTest.java | 6 +- 17 files changed, 424 insertions(+), 103 deletions(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java index 7bbc94c19..b49ec987a 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java @@ -39,7 +39,7 @@ public CollectionHandle( this.query = new WeaviateQueryClient<>(collection, grpcTransport, defaults); this.generate = new WeaviateGenerateClient<>(collection, grpcTransport, defaults); this.data = new WeaviateDataClient<>(collection, restTransport, grpcTransport, defaults); - this.batch = new WeaviateBatchClient<>(grpcTransport, defaults); + this.batch = new WeaviateBatchClient<>(grpcTransport, collection, defaults); this.defaults = defaults; this.tenants = new WeaviateTenantsClient(collection, restTransport, grpcTransport); @@ -127,7 +127,7 @@ public 
CollectionHandle withConsistencyLevel(ConsistencyLevel consi } /** Default tenant for requests. */ - public String tenant() { + public Optional tenant() { return defaults.tenant(); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java index 0b29d2c82..14c551d18 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java @@ -102,7 +102,7 @@ public CollectionHandleAsync withConsistencyLevel(ConsistencyLevel } /** Default tenant for requests. */ - public String tenant() { + public Optional tenant() { return defaults.tenant(); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaults.java b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaults.java index f89e9f6f1..d9026805a 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaults.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaults.java @@ -11,7 +11,7 @@ import io.weaviate.client6.v1.api.collections.query.ConsistencyLevel; import io.weaviate.client6.v1.internal.ObjectBuilder; -public record CollectionHandleDefaults(Optional consistencyLevel, String tenant) { +public record CollectionHandleDefaults(Optional consistencyLevel, Optional tenant) { /** * Set default values for query / aggregation requests. * @@ -36,7 +36,7 @@ public static Function> none() } public CollectionHandleDefaults(Builder builder) { - this(Optional.of(builder.consistencyLevel), builder.tenant); + this(Optional.of(builder.consistencyLevel), Optional.of(builder.tenant)); } public static final class Builder implements ObjectBuilder { @@ -63,16 +63,12 @@ public CollectionHandleDefaults build() { /** Serialize default values to a URL query. 
*/ public Map queryParameters() { - if (consistencyLevel.isEmpty() && tenant == null) { + if (consistencyLevel.isEmpty() && tenant.isEmpty()) { return Collections.emptyMap(); } - var query = new HashMap(); - if (consistencyLevel.isPresent()) { - query.put("consistency_level", consistencyLevel.get()); - } - if (tenant != null) { - query.put("tenant", tenant); - } + Map query = new HashMap(); + consistencyLevel.ifPresent(v -> query.put("consistency_level", v)); + tenant.ifPresent(v -> query.put("tenant", v)); return query; } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AggregateRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AggregateRequest.java index fa8290f64..7d048937a 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AggregateRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AggregateRequest.java @@ -25,9 +25,7 @@ static Rpc { diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java index 167087fa2..175b4c4dc 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java @@ -4,11 +4,18 @@ import java.io.Closeable; import java.io.IOException; -import java.util.Optional; +import java.util.Arrays; +import java.util.Collection; +import java.util.EnumSet; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.Condition; import 
java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; @@ -16,8 +23,11 @@ import javax.annotation.concurrent.GuardedBy; import io.grpc.stub.StreamObserver; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; import io.weaviate.client6.v1.api.collections.WeaviateObject; -import io.weaviate.client6.v1.api.collections.query.ConsistencyLevel; +import io.weaviate.client6.v1.api.collections.data.BatchReference; +import io.weaviate.client6.v1.api.collections.data.InsertManyRequest; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; /** * BatchContext stores the state of an active batch process @@ -35,8 +45,11 @@ public final class BatchContext implements Closeable { private final int DEFAULT_BATCH_SIZE = 1000; private final int DEFAULT_QUEUE_SIZE = 100; + private final CollectionDescriptor collectionDescriptor; + private final CollectionHandleDefaults collectionHandleDefaults; + + /** Stream factory creates new streams. */ private final StreamFactory streamFactory; - private final Optional consistencyLevel; /** * Queue publishes insert tasks from the main thread to the "sender". @@ -79,24 +92,83 @@ public final class BatchContext implements Closeable { /** stateChanged notifies threads about a state transition. */ private final Condition stateChanged = lock.newCondition(); + /** + * Internal execution service. It's lifecycle is bound to that of the + * BatchContext: it's started when the context is initialized + * and shutdown on {@link #close}. + */ + private final ExecutorService exec = Executors.newSingleThreadExecutor(); + + /** Service thread pool for OOM timer. */ + private final ScheduledExecutorService scheduledExec = Executors.newScheduledThreadPool(1); + + /** + * Currently open stream. This will be created on {@link #start}. + * Other threads MAY use stream but MUST NOT update this field on their own. 
+ */ + private volatile StreamObserver messages; + + private volatile StreamObserver events; + BatchContext( StreamFactory streamFactory, int maxSizeBytes, - Optional consistencyLevel) { + CollectionDescriptor collectionDescriptor, + CollectionHandleDefaults collectionHandleDefaults) { this.streamFactory = requireNonNull(streamFactory, "streamFactory is null"); - this.consistencyLevel = requireNonNull(consistencyLevel, "consistencyLevel is null"); + this.collectionDescriptor = requireNonNull(collectionDescriptor, "collectionDescriptor is null"); + this.collectionHandleDefaults = requireNonNull(collectionHandleDefaults, "collectionHandleDefaults is null"); this.queue = new ArrayBlockingQueue<>(DEFAULT_QUEUE_SIZE); this.batch = new Batch(DEFAULT_BATCH_SIZE, maxSizeBytes); } + /** Add {@link WeaviateObject} to the batch. */ + public TaskHandle add(WeaviateObject object) throws InterruptedException { + TaskHandle handle = new TaskHandle( + object, + InsertManyRequest.buildObject(object, collectionDescriptor, collectionHandleDefaults)); + return add(handle); + } + + /** Add {@link BatchReference} to the batch. */ + public TaskHandle add(BatchReference reference) throws InterruptedException { + TaskHandle handle = new TaskHandle( + reference, + InsertManyRequest.buildReference(reference, collectionHandleDefaults.tenant())); + return add(handle); + } + void start() { + Recv recv = new Recv(); + messages = streamFactory.createStream(recv); + events = recv; + Send send = new Send(); + exec.execute(send); } void reconnect() { } + /** + * Retry a task. + * + * BatchContext does not impose any limit on the number of times a task can + * be retried -- it is up to the user to implement an appropriate retry policy. 
+ * + * @see TaskHandle#timesRetried + */ + public TaskHandle retry(TaskHandle taskHandle) throws InterruptedException { + return add(taskHandle.retry()); + } + + @Override + public void close() throws IOException { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'close'"); + } + /** Set the new state and notify awaiting threads. */ void setState(State nextState) { requireNonNull(nextState, "nextState is null"); @@ -110,15 +182,6 @@ void setState(State nextState) { } } - boolean canSend() { - lock.lock(); - try { - return state.canSend(); - } finally { - lock.unlock(); - } - } - /** onEvent delegates event handling to {@link #state} */ void onEvent(Event event) throws InterruptedException { lock.lock(); @@ -129,52 +192,297 @@ void onEvent(Event event) throws InterruptedException { } } - /** Add {@link WeaviateObject} to the batch. */ - public TaskHandle add() { - return null; + private TaskHandle add(final TaskHandle taskHandle) throws InterruptedException { + // TODO(dyma): check that we haven't closed the stream on our end yet + // probably with some state.isClosed() or something + // TODO(dyma): check that wip doesn't have that ID yet, otherwise + // we can lose some data (results) + queue.put(taskHandle); + return taskHandle; } - public TaskHandle retry(TaskHandle taskHandle) { - return null; - } + private final class Send implements Runnable { + @Override + public void run() { + try { + // trySend exists normally + trySend(); + messages.onCompleted(); + return; + } catch (InterruptedException ignored) { + // TODO(dyma): interrupted (whether through the exception + // by breaking the while loop. 
Restore the interrupted status + // and update the state + } + } + + private void trySend() throws InterruptedException { + try { + while (!Thread.currentThread().isInterrupted()) { + // if batch is full: + // -> if the stream is closed / status is error (error) return + // else send and await ack + // + // take the next item in the queue + // -> if POISON: drain the batch, call onComplete, return + // + // add to batch - private final class Sender implements Runnable { - private final StreamObserver stream; + // TODO(dyma): check that the batch is + if (batch.isFull()) { + send(); + } - private Sender(StreamObserver stream) { - this.stream = requireNonNull(stream, "stream is null"); + TaskHandle task = queue.take(); + + if (task == TaskHandle.POISON) { + drain(); + return; + } + + Data data = task.data(); + if (!batch.add(data)) { + // FIXME(dyma): once we've removed a task from the queue, we must + // ensure that it makes it's way to the batch, otherwise we lose + // that task. Here, for example, send() can be interrupted in which + // case the task is lost. + // How do we fix this? We cannot ignore the interrupt, because + // interrupted send() means the batch was not acked, so it will + // not accept any new items. + // + // Maybe! batch.add should put the data in the backlog if it couldn't + // fit it in the buffer!!! + // Yes!!!!!!! The backlog is not limited is size, so it will fit any + // data that does not exceed maxGrpcMessageSize. We wouldn't need to + // do a second pass ourselves. + send(); + boolean ok = batch.add(data); + assert ok : "batch.add must succeed after send"; + } + + // TODO(dyma): check that the previous is null, + // we should've checked for that upstream in add(). 
+ TaskHandle existing = wip.put(task.id(), task); + assert existing == null : "duplicate tasks in progress, id=" + existing.id(); + } + } catch (DataTooBigException e) { + // TODO(dyma): fail + } } - @Override - public void run() { - throw new UnsupportedOperationException("implement!"); + private void send() throws InterruptedException { + // This will stop sending as soon as we get the batch not a "not full" state. + // The reason we do that is to account for the backlog, which might re-fill + // the batch's buffer after .clear(). + while (batch.isFull()) { + flush(); + } + assert !batch.isFull() : "batch is full after send"; + } + + private void drain() throws InterruptedException { + // This will send until ALL items in the batch have been sent. + while (!batch.isEmpty()) { + flush(); + } + assert batch.isEmpty() : "batch not empty after drain"; + } + + private void flush() throws InterruptedException { + // TODO(dyma): if we're in OOM / ServerShuttingDown state, then we known there + // isn't any reason to keep waiting for the acks. However, we cannot exit + // without taking a poison pill from the queue, because this risks blocking the + // producer thread. + // So we patiently wait, relying purely on the 2 booleans: canSend and + // "isAcked". maybe not "isAcked" but "canAdd" / "canAccept"? + + // FIXME(dyma): draining the batch is not a good idea because the backlog + // is likely smaller that the maxSize, so we'd sending half-empty batches. 
+ + awaitCanSend(); + messages.onNext(batch.prepare()); + setState(AWAIT_ACKS); + awaitAcked(); // TODO(dyma): rename canTake into something like awaitAcked(); + // the method can be called boolean isInFlight(); + } + + private void awaitCanSend() throws InterruptedException { + lock.lock(); + try { + while (!state.canSend()) { + stateChanged.await(); + } + } finally { + lock.unlock(); + } + } + + // TODO(dyma): the semantics of "canTake" is rather "can I put more data in the + // batch", even more precisely -- "is the batch still in-flight or is it open?" + private void awaitAcked() throws InterruptedException { + lock.lock(); + try { + while (!state.canTake()) { + stateChanged.await(); + } + } finally { + // Not a good assertion: batch could've been re-populated from the backlog. + // assert !batch.isFull() : "take allowed with full batch"; + lock.unlock(); + } } } private final class Recv implements StreamObserver { + @Override + public void onNext(Event event) { + try { + BatchContext.this.onEvent(event); + } catch (InterruptedException e) { + // TODO(dyma): cancel the RPC (req.onError()) + } catch (Exception e) { + // TODO(dyma): cancel with + } + } + @Override public void onCompleted() { - // TODO Auto-generated method stub - throw new UnsupportedOperationException("Unimplemented method 'onCompleted'"); + // TODO(dyma): server closed its side of the stream successfully + // Maybe log, but there's nothing that we need to do here + // This is the EOF that the protocol document is talking about } @Override public void onError(Throwable arg0) { - // TODO Auto-generated method stub - throw new UnsupportedOperationException("Unimplemented method 'onError'"); + // TODO(dyma): if we did req.onError(), then the error can be ignored + // The exception should be set somewhere so all threads can observe it + } + } + + final State CLOSED = new BaseState(); + final State AWAIT_STARTED = new BaseState(BaseState.Action.TAKE) { + @Override + public void onEvent(Event event) { 
+ if (requireNonNull(event, "event is null") == Event.STARTED) { + setState(ACTIVE); + return; + } + super.onEvent(event); + } + }; + final State ACTIVE = new BaseState(BaseState.Action.TAKE, BaseState.Action.SEND); + final State AWAIT_ACKS = new BaseState() { + @Override + public void onEvent(Event event) { + requireNonNull(event, "event is null"); + + if (event instanceof Event.Acks acks) { + Collection remaining = batch.clear(); + if (!remaining.isEmpty()) { + // TODO(dyma): throw an exception -- this is bad + } + // TODO(dyma): should we check if wip contains ID? + // TODO(dyma): do we need to synchronize here? I don't think so... + acks.acked().forEach(ack -> wip.get(ack).setAcked()); + setState(ACTIVE); + } else if (event == Event.OOM) { + int delaySeconds = 300; + setState(new Oom(delaySeconds)); + } else { + super.onEvent(event); + } + } + }; + + private class BaseState implements State { + private final EnumSet permitted; + + enum Action { + TAKE, SEND; + } + + protected BaseState(Action... 
actions) { + this.permitted = EnumSet.copyOf(Arrays.asList(requireNonNull(actions, "actions is null"))); + } + + @Override + public boolean canSend() { + return permitted.contains(Action.SEND); + } + + @Override + public boolean canTake() { + return permitted.contains(Action.TAKE); + } + + @Override + public void onEvent(Event event) { + requireNonNull(event, "event is null"); + + if (event instanceof Event.Results results) { + results.successful().forEach(id -> wip.get(id).setSuccess()); + results.errors().forEach((id, error) -> wip.get(id).setError(error)); + } else if (event instanceof Event.Backoff backoff) { + batch.setMaxSize(backoff.maxSize()); + } else if (event == Event.SHUTTING_DOWN) { + setState(new ServerShuttingDown(this)); + } else { + throw new IllegalStateException("cannot handle " + event.getClass()); + } + } + } + + private class Oom extends BaseState { + private final ScheduledFuture shutdown; + + private Oom(long delaySeconds) { + super(); + this.shutdown = scheduledExec.schedule(this::initiateShutdown, delaySeconds, TimeUnit.SECONDS); + } + + private void initiateShutdown() { + if (Thread.currentThread().isInterrupted()) { + return; + } + events.onNext(Event.SHUTTING_DOWN); + events.onNext(Event.SHUTDOWN); } @Override - public void onNext(Event arg0) { - // TODO Auto-generated method stub - throw new UnsupportedOperationException("Unimplemented method 'onNext'"); + public void onEvent(Event event) { + if (requireNonNull(event, "event is null") != Event.SHUTTING_DOWN) { + throw new IllegalStateException("Expected OOM to be followed by ShuttingDown"); + } + + shutdown.cancel(true); + setState(new ServerShuttingDown(this)); } } - @Override - public void close() throws IOException { - // TODO Auto-generated method stub - throw new UnsupportedOperationException("Unimplemented method 'close'"); + private class ServerShuttingDown implements State { + private final boolean canTake; + + private ServerShuttingDown(State previous) { + this.canTake = 
requireNonNull(previous, "previous is null").getClass() == Oom.class; + } + + @Override + public boolean canTake() { + return canTake; + } + + @Override + public boolean canSend() { + return false; + } + + @Override + public void onEvent(Event event) throws InterruptedException { + if (requireNonNull(event, "event is null") != Event.SHUTDOWN) { + throw new IllegalStateException("Expected ShuttingDown to be followed by Shutdown"); + } + setState(CLOSED); + } } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java index 52a2adab6..b16a2be94 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java @@ -1,7 +1,5 @@ package io.weaviate.client6.v1.api.collections.batch; -import static java.util.Objects.requireNonNull; - interface State { /** * canSend returns a boolean indicating if sending @@ -9,6 +7,12 @@ interface State { */ boolean canSend(); + /** + * canTake returns a boolean indicating if accepting + * more items into the batch is allowed in this state. + */ + boolean canTake(); + /** * onEvent handles incoming events; these can be generated by the server * or by a different part of the program -- the {@link State} MUST NOT @@ -22,25 +26,4 @@ interface State { * a separate process, e.g. the OOM timer. 
*/ void onEvent(Event event) throws InterruptedException; - - abstract class BaseState implements State { - @Override - public void onEvent(Event event) { - requireNonNull(event, "event is null"); - - if (event instanceof Event.Acks acks) { - - } else if (event instanceof Event.Results results) { - - } else if (event instanceof Event.Backoff backoff) { - - } else if (event == Event.OOM) { - - } else if (event == Event.SHUTTING_DOWN) { - - } else { - throw new IllegalStateException("cannot handle " + event.getClass()); - } - } - } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java index e7c6799cc..7752246e6 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java @@ -11,7 +11,7 @@ import com.google.protobuf.GeneratedMessageV3; import io.weaviate.client6.v1.api.collections.WeaviateObject; -import io.weaviate.client6.v1.api.collections.data.ObjectReference; +import io.weaviate.client6.v1.api.collections.data.BatchReference; @ThreadSafe @SuppressWarnings("deprecation") // protoc uses GeneratedMessageV3 @@ -59,9 +59,9 @@ private TaskHandle(Data data, int retries) { this(new Data(object, object.uuid(), data, Data.Type.OBJECT), 0); } - /** Constructor for {@link ObjectReference}. */ - TaskHandle(ObjectReference reference, GeneratedMessage.ExtendableMessage data) { - this(new Data(reference, reference.beacon(), data, Data.Type.REFERENCE), 0); + /** Constructor for {@link BatchReference}. */ + TaskHandle(BatchReference reference, GeneratedMessage.ExtendableMessage data) { + this(new Data(reference, reference.target().beacon(), data, Data.Type.REFERENCE), 0); } /** @@ -93,6 +93,10 @@ String id() { return data.id(); } + Data data() { + return data; + } + /** Set the {@link #acked} flag. 
*/ void setAcked() { acked.complete(null); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/WeaviateBatchClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/WeaviateBatchClient.java index d62856c24..551d3adcb 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/WeaviateBatchClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/WeaviateBatchClient.java @@ -6,19 +6,26 @@ import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; import io.weaviate.client6.v1.internal.grpc.GrpcTransport; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; public class WeaviateBatchClient { private final CollectionHandleDefaults defaults; + private final CollectionDescriptor collectionDescriptor; private final GrpcTransport grpcTransport; - public WeaviateBatchClient(GrpcTransport grpcTransport, CollectionHandleDefaults defaults) { + public WeaviateBatchClient( + GrpcTransport grpcTransport, + CollectionDescriptor collectionDescriptor, + CollectionHandleDefaults defaults) { this.defaults = requireNonNull(defaults, "defaults is null"); + this.collectionDescriptor = requireNonNull(collectionDescriptor, "collectionDescriptor is null"); this.grpcTransport = requireNonNull(grpcTransport, "grpcTransport is null"); } /** Copy constructor with new defaults. 
*/ public WeaviateBatchClient(WeaviateBatchClient c, CollectionHandleDefaults defaults) { this.defaults = requireNonNull(defaults, "defaults is null"); + this.collectionDescriptor = c.collectionDescriptor; this.grpcTransport = c.grpcTransport; } @@ -28,6 +35,10 @@ public BatchContext start() { throw new IllegalStateException("Server must have grpcMaxMessageSize configured to use server-side batching"); } StreamFactory streamFactory = new TranslatingStreamFactory(grpcTransport::createStream); - return new BatchContext<>(streamFactory, maxSizeBytes.getAsInt(), defaults.consistencyLevel()); + return new BatchContext<>( + streamFactory, + maxSizeBytes.getAsInt(), + collectionDescriptor, + defaults); } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/config/GetShardsRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/config/GetShardsRequest.java index 5237fed26..c27c13aa7 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/config/GetShardsRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/config/GetShardsRequest.java @@ -21,8 +21,8 @@ public static final Endpoint> endpoint( return SimpleEndpoint.noBody( request -> "GET", request -> "/schema/" + collection.collectionName() + "/shards", - request -> defaults.tenant() != null - ? Map.of("tenant", defaults.tenant()) + request -> defaults.tenant().isPresent() + ? 
Map.of("tenant", defaults.tenant().get()) : Collections.emptyMap(), (statusCode, response) -> (List) JSON.deserialize(response, TypeToken.getParameterized( List.class, Shard.class))); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/BatchReference.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/BatchReference.java index a83652be5..503ec59d1 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/BatchReference.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/BatchReference.java @@ -1,5 +1,7 @@ package io.weaviate.client6.v1.api.collections.data; +import static java.util.Objects.requireNonNull; + import java.io.IOException; import java.util.Arrays; @@ -9,7 +11,14 @@ import io.weaviate.client6.v1.api.collections.WeaviateObject; -public record BatchReference(String fromCollection, String fromProperty, String fromUuid, ObjectReference reference) { +public record BatchReference(String fromCollection, String fromProperty, String fromUuid, ObjectReference target) { + + public BatchReference { + requireNonNull(fromCollection, "fromCollection is null"); + requireNonNull(fromProperty, "fromProperty is null"); + requireNonNull(fromUuid, "fromUuid is null"); + requireNonNull(target, "target is null"); + } public static BatchReference[] objects(WeaviateObject fromObject, String fromProperty, WeaviateObject... 
toObjects) { @@ -39,7 +48,7 @@ public void write(JsonWriter out, BatchReference value) throws IOException { out.value(ObjectReference.toBeacon(value.fromCollection, value.fromProperty, value.fromUuid)); out.name("to"); - out.value(ObjectReference.toBeacon(value.reference.collection(), value.reference.uuid())); + out.value(ObjectReference.toBeacon(value.target.collection(), value.target.uuid())); out.endObject(); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyRequest.java index 601ea072d..e60133957 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyRequest.java @@ -29,9 +29,7 @@ public static Rpc WeaviateProtoBatch.BatchObject buildObject( }).toList(); object.addAllVectors(vectors); } - if (defaults.tenant() != null) { - object.setTenant(defaults.tenant()); - } + defaults.tenant().ifPresent(object::setTenant); var singleRef = new ArrayList(); var multiRef = new ArrayList(); @@ -331,4 +332,20 @@ private static com.google.protobuf.Struct marshalStruct(Map prop }); return struct.build(); } + + public static WeaviateProtoBatch.BatchReference buildReference(BatchReference reference, Optional tenant) { + requireNonNull(reference, "reference is null"); + WeaviateProtoBatch.BatchReference.Builder builder = WeaviateProtoBatch.BatchReference.newBuilder(); + builder + .setName(reference.fromProperty()) + .setFromCollection(reference.fromCollection()) + .setFromUuid(reference.fromUuid()) + .setToUuid(reference.target().uuid()); + + if (reference.target().collection() != null) { + builder.setToCollection(reference.target().collection()); + } + tenant.ifPresent(t -> builder.setTenant(t)); + return builder.build(); + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java 
b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java index 86840ed5d..3bfcdf52d 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java @@ -34,7 +34,7 @@ public static final Endpoint, Wea new WeaviateObject<>( request.object.uuid(), collection.collectionName(), - defaults.tenant(), + defaults.tenant().get(), request.object.properties(), request.object.vectors(), request.object.createdAt(), diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReplaceObjectRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReplaceObjectRequest.java index 81dbb2428..45419afff 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReplaceObjectRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReplaceObjectRequest.java @@ -33,7 +33,7 @@ static final Endpoint, Void> end new WeaviateObject<>( request.object.uuid(), collection.collectionName(), - defaults.tenant(), + defaults.tenant().get(), request.object.properties(), request.object.vectors(), request.object.createdAt(), diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/UpdateObjectRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/UpdateObjectRequest.java index cf0451d0d..65a1a66f1 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/UpdateObjectRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/UpdateObjectRequest.java @@ -33,7 +33,7 @@ static final Endpoint, Void> endp new WeaviateObject<>( request.object.uuid(), collection.collectionName(), - defaults.tenant(), + defaults.tenant().get(), request.object.properties(), request.object.vectors(), request.object.createdAt(), diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java 
b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java index e723c2c03..8bb0db91b 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java @@ -37,10 +37,7 @@ public static WeaviateProtoSearchGet.SearchRequest marshal( } request.operator.appendTo(message); - if (defaults.tenant() != null) { - message.setTenant(defaults.tenant()); - } - + defaults.tenant().ifPresent(message::setTenant); if (defaults.consistencyLevel().isPresent()) { defaults.consistencyLevel().get().appendTo(message); } diff --git a/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleTest.java b/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleTest.java index 9ff578cfe..50286503a 100644 --- a/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleTest.java +++ b/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleTest.java @@ -139,10 +139,10 @@ public void test_collectionHandleDefaults_rest(String __, switch (tenantLoc) { case QUERY: - Assertions.assertThat(query).containsEntry("tenant", defaults.tenant()); + Assertions.assertThat(query).containsEntry("tenant", defaults.tenant().get()); break; case BODY: - assertJsonHasValue(body, "tenant", defaults.tenant()); + assertJsonHasValue(body, "tenant", defaults.tenant().get()); } }); } @@ -219,7 +219,7 @@ public void test_defaultTenant_getShards() throws IOException { // Assert rest.assertNext((method, requestUrl, body, query) -> { - Assertions.assertThat(query).containsEntry("tenant", defaults.tenant()); + Assertions.assertThat(query).containsEntry("tenant", defaults.tenant().get()); }); } From 4dad5e0ca6f7744a89167a9b7cef4e13f9115d97 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 11 Feb 2026 01:56:30 +0100 Subject: [PATCH 05/22] chore(batch): remove old comment --- .../client6/v1/api/collections/batch/Batch.java | 13 ------------- 1 file 
changed, 13 deletions(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java index c4cfc479f..968f8503e 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java @@ -18,19 +18,6 @@ import io.weaviate.client6.v1.api.collections.batch.Event.Backoff; import io.weaviate.client6.v1.internal.grpc.GrpcChannelOptions; -// assert maxSize > 0 : "non-positive maxSize"; -// assert maxSizeBytes > 0 : "non-positive maxSizeBytes"; -// assert sizeBytes >= 0 : "negative sizeBytes"; -// assert buffer.size() <= maxSize : "buffer exceeds maxSize"; -// assert sizeBytes <= maxSizeBytes : "message exceeds maxSizeBytes"; -// if (buffer.size() < maxSize) { -// assert backlog.isEmpty() : "backlog not empty when buffer not full"; -// } -// if (buffer.isEmpty()) { -// assert sizeBytes == 0 : "sizeBytes must be 0 when buffer is empty"; -// } -// assert pendingMaxSize != null : "pending max size is null"; - /** * Batch can be in either of 2 states: *

    From 7e510dde6ee08a9ced3babcb5eab40b8e3fb5a2a Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Fri, 13 Feb 2026 17:13:00 +0100 Subject: [PATCH 06/22] feat(batch): implement graceful shutdown and abort --- .../v1/api/collections/batch/Batch.java | 37 +- .../api/collections/batch/BatchContext.java | 414 +++++++++++++----- .../batch/DataTooBigException.java | 9 +- .../batch/DuplicateTaskException.java | 23 + .../v1/api/collections/batch/Event.java | 25 +- .../batch/ProtocolViolationException.java | 48 ++ .../v1/api/collections/batch/State.java | 17 +- .../v1/api/collections/batch/TaskHandle.java | 14 + 8 files changed, 454 insertions(+), 133 deletions(-) create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/batch/DuplicateTaskException.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/batch/ProtocolViolationException.java diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java index 968f8503e..6cb9d6051 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java @@ -188,6 +188,8 @@ synchronized void setMaxSize(int maxSizeNew) { extra.remove(); } // Reverse the backlog to restore the FIFO order. + // FIXME(dyma): this assumes setMaxSize is called on an empty backlog, + // but that's is simply not true. Collections.reverse(backlog); } finally { checkInvariants(); @@ -197,16 +199,22 @@ synchronized void setMaxSize(int maxSizeNew) { /** * Add a data item to the batch. * + *

    + * We want to guarantee that, once a work item has been taken from the queue, + * it's going to be eventually executed. Because we cannot know if an item + * will overflow the batch before it's removed from the queue, the simplest + * and safest way to deal with it is to allow {@link Batch} to put + * the overflowing item in the {@link #backlog}. The batch is considered + * full after that and will not accept any more items until it's cleared. * * @throws DataTooBigException If the data exceeds the maximum * possible batch size. * @throws IllegalStateException If called on an "in-flight" batch. * @see #prepare * @see #inFlight - * - * @return Boolean indicating if the item has been accepted. + * @see #clear */ - synchronized boolean add(Data data) throws IllegalStateException, DataTooBigException { + synchronized void add(Data data) throws IllegalStateException, DataTooBigException { requireNonNull(data, "data is null"); checkInvariants(); @@ -214,14 +222,23 @@ synchronized boolean add(Data data) throws IllegalStateException, DataTooBigExce if (inFlight) { throw new IllegalStateException("Batch is in-flight"); } - if (data.sizeBytes() > maxSizeBytes - sizeBytes) { - if (isEmpty()) { - throw new DataTooBigException(data, maxSizeBytes); - } - return false; + long remainingBytes = maxSizeBytes - sizeBytes; + if (data.sizeBytes() <= remainingBytes) { + addSafe(data); + return; + } + if (isEmpty()) { + throw new DataTooBigException(data, maxSizeBytes); } - addSafe(data); - return true; + // One of the class's invariants is that the backlog must not contain + // any items unless the buffer is full. In case this item overflows + // the buffer, we put it in the backlog, but pretend the maxSizeBytes + // has been reached to satisfy the invariant. + // This doubles as a safeguard to ensure the caller cannot add any + // more items to the batch before flushing it. 
+ backlog.add(data); + sizeBytes += remainingBytes; + assert isFull() : "batch must be full after an overflow"; } finally { checkInvariants(); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java index 175b4c4dc..b896f6b9c 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java @@ -7,10 +7,15 @@ import java.util.Arrays; import java.util.Collection; import java.util.EnumSet; +import java.util.List; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; @@ -19,12 +24,15 @@ import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; +import java.util.function.BiConsumer; import javax.annotation.concurrent.GuardedBy; +import io.grpc.Status; import io.grpc.stub.StreamObserver; import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; import io.weaviate.client6.v1.api.collections.WeaviateObject; +import io.weaviate.client6.v1.api.collections.batch.Event.RpcError; import io.weaviate.client6.v1.api.collections.data.BatchReference; import io.weaviate.client6.v1.api.collections.data.InsertManyRequest; import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; @@ -96,20 +104,50 @@ public final class BatchContext implements Closeable { * Internal execution service. 
Its lifecycle is bound to that of the * BatchContext: it's started when the context is initialized * and shutdown on {@link #close}. + * + *

    + * In the event of abrupt stream termination ({@link Recv#onError} is called), + * the "recv" thread MAY shutdown this service in order to interrupt the "send" + * thread; the latter may be blocked on {@link Send#awaitCanSend} or + * {@link Send#awaitCanPrepareNext}. */ - private final ExecutorService exec = Executors.newSingleThreadExecutor(); + private final ExecutorService sendExec = Executors.newSingleThreadExecutor(); - /** Service thread pool for OOM timer. */ + /** + * Scheduled thread pool for OOM timer. + * + * @see Oom + */ private final ScheduledExecutorService scheduledExec = Executors.newScheduledThreadPool(1); + /** Service executor for polling {@link #workers} status before closing. */ + private final ExecutorService closeExec = Executors.newSingleThreadExecutor(); + /** - * Currently open stream. This will be created on {@link #start}. + * Client-side part of the current stream, created on {@link #start}. * Other threads MAY use stream but MUST NOT update this field on their own. */ private volatile StreamObserver messages; + /** + * Server-side part of the current stream, created on {@link #start}. + * Other threads MAY use stream but MUST NOT update this field on their own. + */ private volatile StreamObserver events; + /** + * Latch reaches zero once both "send" (client side) and "recv" (server side) + * parts of the stream have closed. After a {@link reconnect}, the latch is + * reset. + */ + private volatile CountDownLatch workers; + + /** done completes the stream. */ + private final CompletableFuture closed = new CompletableFuture<>(); + + /** Thread which created the BatchContext. 
*/ + private final Thread parent = Thread.currentThread(); + BatchContext( StreamFactory streamFactory, int maxSizeBytes, @@ -140,15 +178,19 @@ public TaskHandle add(BatchReference reference) throws InterruptedException { } void start() { + workers = new CountDownLatch(2); + Recv recv = new Recv(); messages = streamFactory.createStream(recv); events = recv; Send send = new Send(); - exec.execute(send); + sendExec.execute(send); } - void reconnect() { + void reconnect() throws InterruptedException { + workers.await(); + start(); } /** @@ -163,10 +205,59 @@ public TaskHandle retry(TaskHandle taskHandle) throws InterruptedException { return add(taskHandle.retry()); } + /** + * Interrupt all subprocesses, notify the server, de-allocate resources, + * and abort the stream. + * + * @apiNote This is not a normal shutdown process. It is an abrupt termination + * triggered by an exception. + */ + private void abort(Throwable t) { + messages.onError(Status.INTERNAL.withCause(t).asRuntimeException()); + closed.completeExceptionally(t); + parent.interrupt(); + sendExec.shutdown(); + } + @Override public void close() throws IOException { - // TODO Auto-generated method stub - throw new UnsupportedOperationException("Unimplemented method 'close'"); + closeExec.execute(() -> { + try { + queue.put(TaskHandle.POISON); + workers.await(); + closed.complete(null); + } catch (Exception e) { + closed.completeExceptionally(e); + } + }); + + try { + closed.get(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } catch (ExecutionException e) { + throw new IOException(e.getCause()); + } finally { + shutdownExecutors(); + } + } + + private void shutdownExecutors() { + BiConsumer> assertEmpty = (name, pending) -> { + assert pending.isEmpty() : "'%s' service had %d tasks awaiting execution" + .formatted(pending.size(), name); + }; + + List pending; + + pending = sendExec.shutdownNow(); + assertEmpty.accept("send", pending); + + pending = 
scheduledExec.shutdownNow(); + assertEmpty.accept("oom", pending); + + pending = closeExec.shutdownNow(); + assertEmpty.accept("close", pending); } /** Set the new state and notify awaiting threads. */ @@ -175,7 +266,9 @@ void setState(State nextState) { lock.lock(); try { + State prev = state; state = nextState; + state.onEnter(prev); stateChanged.signal(); } finally { lock.unlock(); @@ -193,42 +286,45 @@ void onEvent(Event event) throws InterruptedException { } private TaskHandle add(final TaskHandle taskHandle) throws InterruptedException { - // TODO(dyma): check that we haven't closed the stream on our end yet - // probably with some state.isClosed() or something - // TODO(dyma): check that wip doesn't have that ID yet, otherwise - // we can lose some data (results) + if (closed.isDone()) { + throw new IllegalStateException("BatchContext is closed"); + } + + TaskHandle existing = wip.get(taskHandle.id()); + if (existing != null) { + throw new DuplicateTaskException(taskHandle, existing); + } + + existing = wip.put(taskHandle.id(), taskHandle); + assert existing == null : "duplicate tasks in progress, id=" + existing.id(); + queue.put(taskHandle); return taskHandle; } private final class Send implements Runnable { + @Override public void run() { try { - // trySend exists normally trySend(); - messages.onCompleted(); - return; - } catch (InterruptedException ignored) { - // TODO(dyma): interrupted (whether through the exception - // by breaking the while loop. Restore the interrupted status - // and update the state + } catch (Exception e) { + messages.onError(e); + } finally { + workers.countDown(); } } - private void trySend() throws InterruptedException { + /** + * trySend consumes {@link #queue} tasks and sends them in batches until it + * encounters a {@link TaskHandle#POISON}. + * + *

    + * If the method returns normally, it means the queue's been drained. + */ + private void trySend() throws DataTooBigException { try { while (!Thread.currentThread().isInterrupted()) { - // if batch is full: - // -> if the stream is closed / status is error (error) return - // else send and await ack - // - // take the next item in the queue - // -> if POISON: drain the batch, call onComplete, return - // - // add to batch - - // TODO(dyma): check that the batch is if (batch.isFull()) { send(); } @@ -241,47 +337,40 @@ private void trySend() throws InterruptedException { } Data data = task.data(); - if (!batch.add(data)) { - // FIXME(dyma): once we've removed a task from the queue, we must - // ensure that it makes it's way to the batch, otherwise we lose - // that task. Here, for example, send() can be interrupted in which - // case the task is lost. - // How do we fix this? We cannot ignore the interrupt, because - // interrupted send() means the batch was not acked, so it will - // not accept any new items. - // - // Maybe! batch.add should put the data in the backlog if it couldn't - // fit it in the buffer!!! - // Yes!!!!!!! The backlog is not limited is size, so it will fit any - // data that does not exceed maxGrpcMessageSize. We wouldn't need to - // do a second pass ourselves. - send(); - boolean ok = batch.add(data); - assert ok : "batch.add must succeed after send"; - } + batch.add(data); - // TODO(dyma): check that the previous is null, - // we should've checked for that upstream in add(). TaskHandle existing = wip.put(task.id(), task); assert existing == null : "duplicate tasks in progress, id=" + existing.id(); } - } catch (DataTooBigException e) { - // TODO(dyma): fail + } catch (InterruptedException ignored) { + messages.onNext(Message.stop()); + messages.onCompleted(); + Thread.currentThread().interrupt(); } } + /** + * Send the current portion of batch items. 
After this method returns, the batch + * is guaranteed to have space for at least one the next item (not full). + */ private void send() throws InterruptedException { - // This will stop sending as soon as we get the batch not a "not full" state. - // The reason we do that is to account for the backlog, which might re-fill - // the batch's buffer after .clear(). + // Continue flushing until we get the batch to not a "not full" state. + // This is to account for the backlog, which might re-fill the batch + // after .clear(). while (batch.isFull()) { flush(); } assert !batch.isFull() : "batch is full after send"; } + /** + * Send all remaining items in the batch. After this method returns, the batch + * is guaranteed to be empty. + */ private void drain() throws InterruptedException { - // This will send until ALL items in the batch have been sent. + // To correctly drain the batch, we flush repeatedly + // until the batch becomes empty, as clearing a batch + // after an ACK might re-populate it from its internal backlog. while (!batch.isEmpty()) { flush(); } @@ -289,23 +378,18 @@ private void drain() throws InterruptedException { } private void flush() throws InterruptedException { - // TODO(dyma): if we're in OOM / ServerShuttingDown state, then we known there - // isn't any reason to keep waiting for the acks. However, we cannot exit - // without taking a poison pill from the queue, because this risks blocking the - // producer thread. - // So we patiently wait, relying purely on the 2 booleans: canSend and - // "isAcked". maybe not "isAcked" but "canAdd" / "canAccept"? - - // FIXME(dyma): draining the batch is not a good idea because the backlog - // is likely smaller that the maxSize, so we'd sending half-empty batches. 
- awaitCanSend(); messages.onNext(batch.prepare()); - setState(AWAIT_ACKS); - awaitAcked(); // TODO(dyma): rename canTake into something like awaitAcked(); - // the method can be called boolean isInFlight(); + setState(IN_FLIGHT); + + // When we get into OOM / ServerShuttingDown state, then we can be certain that + // there isn't any reason to keep waiting for the ACKs. However, we should not + // exit without either taking a poison pill from the queue, + // or being interrupted, as this risks blocking the producer (main) thread. + awaitCanPrepareNext(); } + /** Block until the current state allows {@link State#canSend}. */ private void awaitCanSend() throws InterruptedException { lock.lock(); try { @@ -317,17 +401,22 @@ private void awaitCanSend() throws InterruptedException { } } - // TODO(dyma): the semantics of "canTake" is rather "can I put more data in the - // batch", even more precisely -- "is the batch still in-flight or is it open?" - private void awaitAcked() throws InterruptedException { + /** + * Block until the current state allows {@link State#canPrepareNext}. + * + *

    + * Depending on the BatchContext lifecycle, the semantics of + * "await can prepare next" can be one of "message is ACK'ed" + * "the stream has started", or, more generally, + * "it is safe to take a next item from the queue and add it to the batch". + */ + private void awaitCanPrepareNext() throws InterruptedException { lock.lock(); try { - while (!state.canTake()) { + while (!state.canPrepareNext()) { stateChanged.await(); } } finally { - // Not a good assertion: batch could've been re-populated from the backlog. - // assert !batch.isFull() : "take allowed with full batch"; lock.unlock(); } } @@ -338,32 +427,50 @@ private final class Recv implements StreamObserver { @Override public void onNext(Event event) { try { - BatchContext.this.onEvent(event); + onEvent(event); } catch (InterruptedException e) { - // TODO(dyma): cancel the RPC (req.onError()) - } catch (Exception e) { - // TODO(dyma): cancel with + // Recv is running on a thread from gRPC's internal thread pool, + // so, while onEvent allows InterruptedException to stay responsive, + // in practice this thread will only be interrupted by the thread pool, + // which already knows it's being shut down. } } + /** + * EOF for the server-side stream. + * By the time this is called, the client-side of the stream had been closed + * and the "send" thread has either exited or is on its way there. + */ @Override public void onCompleted() { - // TODO(dyma): server closed its side of the stream successfully - // Maybe log, but there's nothing that we need to do here - // This is the EOF that the protocol document is talking about + workers.countDown(); + + // boolean stillStuffToDo = true; + // if (stillStuffToDo) { + // reconnect(); + // } } + /** An exception occurred either on our end or in the channel internals. 
*/ @Override - public void onError(Throwable arg0) { - // TODO(dyma): if we did req.onError(), then the error can be ignored - // The exception should be set somewhere so all threads can observe it + public void onError(Throwable t) { + try { + onEvent(Event.RpcError.fromThrowable(t)); + } catch (InterruptedException ignored) { + // Recv is running on a thread from gRPC's internal thread pool, + // so, while onEvent allows InterruptedException to stay responsive, + // in practice this thread will only be interrupted by the thread pool, + // which already knows it's being shut down. + } finally { + workers.countDown(); + } } } - final State CLOSED = new BaseState(); - final State AWAIT_STARTED = new BaseState(BaseState.Action.TAKE) { + final State CLOSED = new BaseState("CLOSED"); + final State AWAIT_STARTED = new BaseState("AWAIT_STARTED", BaseState.Action.PREPARE_NEXT) { @Override - public void onEvent(Event event) { + public void onEvent(Event event) throws InterruptedException { if (requireNonNull(event, "event is null") == Event.STARTED) { setState(ACTIVE); return; @@ -371,20 +478,23 @@ public void onEvent(Event event) { super.onEvent(event); } }; - final State ACTIVE = new BaseState(BaseState.Action.TAKE, BaseState.Action.SEND); - final State AWAIT_ACKS = new BaseState() { + final State ACTIVE = new BaseState("ACTIVE", BaseState.Action.PREPARE_NEXT, BaseState.Action.SEND); + final State IN_FLIGHT = new BaseState("IN_FLIGHT") { @Override - public void onEvent(Event event) { + public void onEvent(Event event) throws InterruptedException { requireNonNull(event, "event is null"); if (event instanceof Event.Acks acks) { Collection remaining = batch.clear(); if (!remaining.isEmpty()) { - // TODO(dyma): throw an exception -- this is bad + throw ProtocolViolationException.incompleteAcks(List.copyOf(remaining)); } - // TODO(dyma): should we check if wip contains ID? - // TODO(dyma): do we need to synchronize here? I don't think so... 
- acks.acked().forEach(ack -> wip.get(ack).setAcked()); + acks.acked().forEach(id -> { + TaskHandle task = wip.get(id); + if (task != null) { + task.setAcked(); + } + }); setState(ACTIVE); } else if (event == Event.OOM) { int delaySeconds = 300; @@ -396,28 +506,34 @@ public void onEvent(Event event) { }; private class BaseState implements State { + private final String name; private final EnumSet permitted; enum Action { - TAKE, SEND; + PREPARE_NEXT, SEND; } - protected BaseState(Action... actions) { + protected BaseState(String name, Action... actions) { + this.name = name; this.permitted = EnumSet.copyOf(Arrays.asList(requireNonNull(actions, "actions is null"))); } + @Override + public void onEnter(State prev) { + } + @Override public boolean canSend() { return permitted.contains(Action.SEND); } @Override - public boolean canTake() { - return permitted.contains(Action.TAKE); + public boolean canPrepareNext() { + return permitted.contains(Action.PREPARE_NEXT); } @Override - public void onEvent(Event event) { + public void onEvent(Event event) throws InterruptedException { requireNonNull(event, "event is null"); if (event instanceof Event.Results results) { @@ -427,20 +543,38 @@ public void onEvent(Event event) { batch.setMaxSize(backoff.maxSize()); } else if (event == Event.SHUTTING_DOWN) { setState(new ServerShuttingDown(this)); + } else if (event instanceof Event.RpcError error) { + setState(new StreamAborted(error.exception())); } else { - throw new IllegalStateException("cannot handle " + event.getClass()); + throw ProtocolViolationException.illegalStateTransition(this, event); } } + + @Override + public String toString() { + return name; + } } - private class Oom extends BaseState { - private final ScheduledFuture shutdown; + /** + * Oom waits for {@link Event#SHUTTING_DOWN} up to a specified amount of time, + * after which it will force stream termination by imitating server shutdown. 
+ */ + private final class Oom extends BaseState { + private final long delaySeconds; + private ScheduledFuture shutdown; private Oom(long delaySeconds) { - super(); - this.shutdown = scheduledExec.schedule(this::initiateShutdown, delaySeconds, TimeUnit.SECONDS); + super("OOM"); + this.delaySeconds = delaySeconds; + } + + @Override + public void onEnter(State prev) { + shutdown = scheduledExec.schedule(this::initiateShutdown, delaySeconds, TimeUnit.SECONDS); } + /** Imitate server shutdown sequence. */ private void initiateShutdown() { if (Thread.currentThread().isInterrupted()) { return; @@ -450,26 +584,38 @@ private void initiateShutdown() { } @Override - public void onEvent(Event event) { - if (requireNonNull(event, "event is null") != Event.SHUTTING_DOWN) { - throw new IllegalStateException("Expected OOM to be followed by ShuttingDown"); + public void onEvent(Event event) throws InterruptedException { + requireNonNull(event, "event"); + if (event == Event.SHUTTING_DOWN || event instanceof RpcError) { + shutdown.cancel(true); + try { + shutdown.get(); + } catch (CancellationException ignored) { + } catch (ExecutionException e) { + throw new RuntimeException(e); + } } - - shutdown.cancel(true); - setState(new ServerShuttingDown(this)); + super.onEvent(event); } } - private class ServerShuttingDown implements State { - private final boolean canTake; + /** + * ServerShuttingDown allows preparing the next batch + * unless the server's OOM'ed on the previous one. + * Once set, the state will shutdown {@link BatchContext#sendExec} + * to instruct the "send" thread to close our part of the stream. 
+ */ + private final class ServerShuttingDown extends BaseState { + private final boolean canPrepareNext; private ServerShuttingDown(State previous) { - this.canTake = requireNonNull(previous, "previous is null").getClass() == Oom.class; + super("SERVER_SHUTTING_DOWN"); + this.canPrepareNext = requireNonNull(previous, "previous is null").getClass() != Oom.class; } @Override - public boolean canTake() { - return canTake; + public boolean canPrepareNext() { + return canPrepareNext; } @Override @@ -477,12 +623,44 @@ public boolean canSend() { return false; } + @Override + public void onEnter(State prev) { + sendExec.shutdown(); + } + + // TODO(dyma): if we agree to retire Shutdown, then ServerShuttingDown + // should not override onEvent and let it fallthough to + // ProtocolViolationException on any event. @Override public void onEvent(Event event) throws InterruptedException { - if (requireNonNull(event, "event is null") != Event.SHUTDOWN) { - throw new IllegalStateException("Expected ShuttingDown to be followed by Shutdown"); + if (requireNonNull(event, "event is null") == Event.SHUTDOWN) { + return; } - setState(CLOSED); + super.onEvent(event); + } + } + + /** + * StreamAborted means the RPC is "dead": the {@link messages} stream is closed + * and using it will result in an {@link IllegalStateException}. + */ + private final class StreamAborted extends BaseState { + private final Throwable t; + + protected StreamAborted(Throwable t) { + super("STREAM_ABORTED"); + this.t = t; + } + + @Override + public void onEnter(State prev) { + abort(t); + } + + @Override + public void onEvent(Event event) { + // StreamAborted cannot transition into another state. It is terminal -- + // BatchContext MUST terminate its subprocesses and close exceptionally. 
} } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/DataTooBigException.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/DataTooBigException.java index e6adb4c12..5160bd27c 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/DataTooBigException.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/DataTooBigException.java @@ -1,11 +1,16 @@ package io.weaviate.client6.v1.api.collections.batch; +import static java.util.Objects.requireNonNull; + +import io.weaviate.client6.v1.api.WeaviateException; + /** * DataTooBigException is thrown when a single object exceeds * the maximum size of a gRPC message. */ -public class DataTooBigException extends Exception { +public class DataTooBigException extends WeaviateException { DataTooBigException(Data data, long maxSizeBytes) { - super("%s with size=%dB exceeds maximum message size %dB".formatted(data, data.sizeBytes(), maxSizeBytes)); + super("%s with size=%dB exceeds maximum message size %dB".formatted( + requireNonNull(data, "data is null"), data.sizeBytes(), maxSizeBytes)); } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/DuplicateTaskException.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/DuplicateTaskException.java new file mode 100644 index 000000000..f0949a6c0 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/DuplicateTaskException.java @@ -0,0 +1,23 @@ +package io.weaviate.client6.v1.api.collections.batch; + +import io.weaviate.client6.v1.api.WeaviateException; + +/** + * DuplicateTaskException is thrown if task is submitted to the batch + * while another task with the same ID is in progress. 
+ */ +public class DuplicateTaskException extends WeaviateException { + private final TaskHandle existing; + + DuplicateTaskException(TaskHandle duplicate, TaskHandle existing) { + super("%s cannot be added to the batch while another task with the same ID is in progress".formatted(duplicate)); + this.existing = existing; + } + + /** + * Get the currently in-progress handle that's a duplicate of the one submitted. + */ + public TaskHandle getExisting() { + return existing; + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java index 855187bd6..03555907c 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java @@ -6,14 +6,16 @@ import java.util.List; import java.util.Map; +import io.grpc.Status; import io.weaviate.client6.v1.api.collections.batch.Event.Acks; import io.weaviate.client6.v1.api.collections.batch.Event.Backoff; import io.weaviate.client6.v1.api.collections.batch.Event.Results; +import io.weaviate.client6.v1.api.collections.batch.Event.RpcError; import io.weaviate.client6.v1.api.collections.batch.Event.Started; import io.weaviate.client6.v1.api.collections.batch.Event.TerminationEvent; sealed interface Event - permits Started, Acks, Results, Backoff, TerminationEvent { + permits Started, Acks, Results, Backoff, TerminationEvent, RpcError { final static Event STARTED = new Started(); final static Event OOM = TerminationEvent.OOM; @@ -36,6 +38,11 @@ record Acks(Collection acked) implements Event { public Acks { acked = List.copyOf(requireNonNull(acked, "acked is null")); } + + @Override + public String toString() { + return "Acks"; + } } /** @@ -50,6 +57,11 @@ record Results(Collection successful, Map errors) implem successful = List.copyOf(requireNonNull(successful, "successful is null")); errors = Map.copyOf(requireNonNull(errors, "errors is null")); } + + @Override + 
public String toString() { + return "Results"; + } } /** @@ -69,6 +81,11 @@ record Results(Collection successful, Map errors) implem * message limit in a new {@link BatchContext}, but is not required to. */ record Backoff(int maxSize) implements Event { + + @Override + public String toString() { + return "Backoff"; + } } enum TerminationEvent implements Event { @@ -113,4 +130,10 @@ enum TerminationEvent implements Event { SHUTDOWN; } + record RpcError(Exception exception) implements Event { + static RpcError fromThrowable(Throwable t) { + Status status = Status.fromThrowable(t); + return new RpcError(status.asException()); + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/ProtocolViolationException.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/ProtocolViolationException.java new file mode 100644 index 000000000..d97b4839c --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/ProtocolViolationException.java @@ -0,0 +1,48 @@ +package io.weaviate.client6.v1.api.collections.batch; + +import static java.util.Objects.requireNonNull; + +import java.util.List; + +import io.weaviate.client6.v1.api.WeaviateException; + +/** + * ProtocolViolationException describes unexpected server behavior in violation + * of the SSB protocol. + * + *

    + * This exception cannot be handled in a meaningful way and should be reported + * to the upstream Weaviate + * project. + */ +public class ProtocolViolationException extends WeaviateException { + ProtocolViolationException(String message) { + super(message); + } + + /** + * Protocol violated because an event arrived while the client is in a state + * which doesn't expect to handle this event. + * + * @param current Current {@link BatchContext} state. + * @param event Server-side event. + * @return ProtocolViolationException with a formatted message. + */ + static ProtocolViolationException illegalStateTransition(State current, Event event) { + return new ProtocolViolationException("%s arrived in %s state".formatted(event, current)); + } + + /** + * Protocol violated because some tasks from the previous Data message + * are not present in the Acks message. + * + * @param remaining IDs of the tasks that weren't ack'ed. MUST be non-empty. + * @return ProtocolViolationException with a formatted message. + */ + static ProtocolViolationException incompleteAcks(List remaining) { + requireNonNull(remaining, "remaining is null"); + return new ProtocolViolationException("IDs from previous Data message missing in Acks: '%s', ... (%d more)" + .formatted(remaining.get(0), remaining.size() - 1)); + + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java index b16a2be94..b3e1e8513 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java @@ -8,10 +8,23 @@ interface State { boolean canSend(); /** - * canTake returns a boolean indicating if accepting + * canPrepareNext returns a boolean indicating if accepting * more items into the batch is allowed in this state. 
*/ - boolean canTake(); + boolean canPrepareNext(); + + /** + * Lifecycle hook that's called after the state is set. + * + *

    + *
    + * This hook MUST be called exactly once. + *
    + * The next state MUST NOT be set until onEnter returns. + * + * @param prev Previous state or null. + */ + void onEnter(State prev); /** * onEvent handles incoming events; these can be generated by the server diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java index 7752246e6..ea93afffe 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java @@ -2,6 +2,7 @@ import static java.util.Objects.requireNonNull; +import java.time.Instant; import java.util.Optional; import java.util.concurrent.CompletableFuture; @@ -47,6 +48,9 @@ public final record Result(Optional error) { /** The number of times this task has been retried. */ private final int retries; + /** Task creation timestamp. */ + private final Instant createdAt = Instant.now(); + private TaskHandle(Data data, int retries) { this.data = requireNonNull(data, "data is null"); @@ -163,4 +167,14 @@ public CompletableFuture result() { public int timesRetried() { return retries; } + + /** Task creation timestamp. Retried tasks have different timestamps. 
*/ + public Instant createdAt() { + return createdAt; + } + + @Override + public String toString() { + return "TaskHandle".formatted(id(), timesRetried(), createdAt()); + } } From 7e22f3929e4d5ee84620aab8ca733bdc9b7935c9 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Mon, 16 Feb 2026 18:08:38 +0100 Subject: [PATCH 07/22] wip(batch): implement shutdown policies --- .../api/collections/batch/BatchContext.java | 113 ++++++++++-------- .../v1/api/collections/batch/State.java | 2 +- 2 files changed, 67 insertions(+), 48 deletions(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java index b896f6b9c..7fca98aca 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java @@ -129,12 +129,6 @@ public final class BatchContext implements Closeable { */ private volatile StreamObserver messages; - /** - * Server-side part of the current stream, created on {@link #start}. - * Other threads MAY use stream but MUST NOT update this field on their own. - */ - private volatile StreamObserver events; - /** * Latch reaches zero once both "send" (client side) and "recv" (server side) * parts of the stream have closed. After a {@link reconnect}, the latch is @@ -142,7 +136,7 @@ public final class BatchContext implements Closeable { */ private volatile CountDownLatch workers; - /** done completes the stream. */ + /** closed completes the stream. */ private final CompletableFuture closed = new CompletableFuture<>(); /** Thread which created the BatchContext. 
*/ @@ -182,7 +176,6 @@ void start() { Recv recv = new Recv(); messages = streamFactory.createStream(recv); - events = recv; Send send = new Send(); sendExec.execute(send); @@ -275,8 +268,8 @@ void setState(State nextState) { } } - /** onEvent delegates event handling to {@link #state} */ - void onEvent(Event event) throws InterruptedException { + /** onEvent delegates event handling to {@link #state}. */ + void onEvent(Event event) { lock.lock(); try { state.onEvent(event); @@ -295,9 +288,6 @@ private TaskHandle add(final TaskHandle taskHandle) throws InterruptedException throw new DuplicateTaskException(taskHandle, existing); } - existing = wip.put(taskHandle.id(), taskHandle); - assert existing == null : "duplicate tasks in progress, id=" + existing.id(); - queue.put(taskHandle); return taskHandle; } @@ -308,8 +298,6 @@ private final class Send implements Runnable { public void run() { try { trySend(); - } catch (Exception e) { - messages.onError(e); } finally { workers.countDown(); } @@ -317,12 +305,9 @@ public void run() { /** * trySend consumes {@link #queue} tasks and sends them in batches until it - * encounters a {@link TaskHandle#POISON}. - * - *

    - * If the method returns normally, it means the queue's been drained. + * encounters a {@link TaskHandle#POISON} or is otherwise interrupted. */ - private void trySend() throws DataTooBigException { + private void trySend() { try { while (!Thread.currentThread().isInterrupted()) { if (batch.isFull()) { @@ -343,10 +328,15 @@ private void trySend() throws DataTooBigException { assert existing == null : "duplicate tasks in progress, id=" + existing.id(); } } catch (InterruptedException ignored) { - messages.onNext(Message.stop()); - messages.onCompleted(); - Thread.currentThread().interrupt(); + // Allow this method to exit normally to close our end of the stream. + } catch (Exception e) { + onEvent(new Event.RpcError(e)); + messages.onError(e); + return; } + + messages.onNext(Message.stop()); + messages.onCompleted(); } /** @@ -426,14 +416,7 @@ private final class Recv implements StreamObserver { @Override public void onNext(Event event) { - try { - onEvent(event); - } catch (InterruptedException e) { - // Recv is running on a thread from gRPC's internal thread pool, - // so, while onEvent allows InterruptedException to stay responsive, - // in practice this thread will only be interrupted by the thread pool, - // which already knows it's being shut down. - } + onEvent(event); } /** @@ -445,22 +428,39 @@ public void onNext(Event event) { public void onCompleted() { workers.countDown(); - // boolean stillStuffToDo = true; - // if (stillStuffToDo) { - // reconnect(); - // } - } + // TODO(dyma): I'm not sure if there isn't a race here. + // Can wip be empty but there still be more work to do? + // wip can be cleared after Results message for the last + // batch arrives. This can only happen in ACTIVE stage, + // because OOM and SERVER_SHUTTING_DOWN imply the last batch + // was not acked. + // Should we double-check that queue.isEmpty()? Can it still + // contain a POISON at this time? 
If it's a normal shutdown + // after #close, then we've already taken the POISON and sent Stop. + // If it happends after SERVER_SHUTTING_DOWN, same. + if (wip.isEmpty()) { + return; + } - /** An exception occurred either on our end or in the channel internals. */ - @Override - public void onError(Throwable t) { try { - onEvent(Event.RpcError.fromThrowable(t)); + reconnect(); } catch (InterruptedException ignored) { // Recv is running on a thread from gRPC's internal thread pool, // so, while onEvent allows InterruptedException to stay responsive, // in practice this thread will only be interrupted by the thread pool, // which already knows it's being shut down. + } + } + + /** An exception occurred either on our end or in the channel internals. */ + @Override + public void onError(Throwable t) { + // TODO(dyma): treat this as a StreamHangup + // After re-connecting, we need to re-submit all WIP tasks, which aren't + // in the queue. Maybe this means we should only add to WIP once we add it + // to the batch? Need to muse on that. 
+ try { + onEvent(Event.RpcError.fromThrowable(t)); } finally { workers.countDown(); } @@ -470,7 +470,7 @@ public void onError(Throwable t) { final State CLOSED = new BaseState("CLOSED"); final State AWAIT_STARTED = new BaseState("AWAIT_STARTED", BaseState.Action.PREPARE_NEXT) { @Override - public void onEvent(Event event) throws InterruptedException { + public void onEvent(Event event) { if (requireNonNull(event, "event is null") == Event.STARTED) { setState(ACTIVE); return; @@ -481,7 +481,7 @@ public void onEvent(Event event) throws InterruptedException { final State ACTIVE = new BaseState("ACTIVE", BaseState.Action.PREPARE_NEXT, BaseState.Action.SEND); final State IN_FLIGHT = new BaseState("IN_FLIGHT") { @Override - public void onEvent(Event event) throws InterruptedException { + public void onEvent(Event event) { requireNonNull(event, "event is null"); if (event instanceof Event.Acks acks) { @@ -533,7 +533,7 @@ public boolean canPrepareNext() { } @Override - public void onEvent(Event event) throws InterruptedException { + public void onEvent(Event event) { requireNonNull(event, "event is null"); if (event instanceof Event.Results results) { @@ -579,18 +579,23 @@ private void initiateShutdown() { if (Thread.currentThread().isInterrupted()) { return; } - events.onNext(Event.SHUTTING_DOWN); - events.onNext(Event.SHUTDOWN); + onEvent(Event.SHUTTING_DOWN); + onEvent(Event.SHUTDOWN); } @Override - public void onEvent(Event event) throws InterruptedException { + public void onEvent(Event event) { requireNonNull(event, "event"); if (event == Event.SHUTTING_DOWN || event instanceof RpcError) { shutdown.cancel(true); try { shutdown.get(); } catch (CancellationException ignored) { + } catch (InterruptedException ignored) { + // Recv is running on a thread from gRPC's internal thread pool, + // so, while onEvent allows InterruptedException to stay responsive, + // in practice this thread will only be interrupted by the thread pool, + // which already knows it's being shut 
down. } catch (ExecutionException e) { throw new RuntimeException(e); } @@ -625,6 +630,13 @@ public boolean canSend() { @Override public void onEnter(State prev) { + // FIXME(dyma): we should have an orderly shutdown with a poison pill here! + // After the reconnect we would want to continue the execution like before. + // Remember, "reconnect" after a normal shutdown should only affect the + // `messages` stream and re-create Recv. The rest of the state is preserved. + // + // If we #shutdown the service it won't be able to execute + // the next Send routine. sendExec.shutdown(); } @@ -632,7 +644,7 @@ public void onEnter(State prev) { // should not override onEvent and let it fallthough to // ProtocolViolationException on any event. @Override - public void onEvent(Event event) throws InterruptedException { + public void onEvent(Event event) { if (requireNonNull(event, "event is null") == Event.SHUTDOWN) { return; } @@ -654,6 +666,13 @@ protected StreamAborted(Throwable t) { @Override public void onEnter(State prev) { + // FIXME(dyma): we need to differentiate between "ABORT" the stream + // and "RECONNECT" on server hangup. + // ABORT should be for sender errors, where we cannot fix something + // by reconnecting. This should be called InternalError state. + // + // RECONNECT deals with StreamHangup event (from Recv#onError). + // We should just reconnect and re-submit all WIP items. abort(t); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java index b3e1e8513..a4ba41310 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java @@ -38,5 +38,5 @@ interface State { * different state via {@link BatchContext#setState(State)}, or start * a separate process, e.g. the OOM timer. 
*/ - void onEvent(Event event) throws InterruptedException; + void onEvent(Event event); } From abb2d78905dc309aa4b52cad4933a2e475a503a3 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Tue, 17 Feb 2026 14:56:51 +0100 Subject: [PATCH 08/22] feat(batch): implement reconnect and shutdown policies --- .../v1/api/collections/batch/Batch.java | 69 +++- .../api/collections/batch/BatchContext.java | 380 +++++++++++------- .../v1/api/collections/batch/Event.java | 80 ++-- .../v1/api/collections/batch/TaskHandle.java | 7 +- .../batch/TranslatingStreamFactory.java | 5 +- 5 files changed, 351 insertions(+), 190 deletions(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java index 6cb9d6051..4a9b156d5 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java @@ -2,15 +2,15 @@ import static java.util.Objects.requireNonNull; -import java.util.ArrayList; +import java.time.Instant; import java.util.Collection; -import java.util.Collections; +import java.util.Comparator; import java.util.Iterator; import java.util.LinkedHashMap; -import java.util.List; import java.util.Map; import java.util.OptionalInt; import java.util.Set; +import java.util.TreeSet; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; @@ -49,7 +49,7 @@ @ThreadSafe final class Batch { /** Backlog MUST be confined to the "receiver" thread. */ - private final List backlog = new ArrayList<>(); + private final TreeSet backlog = new TreeSet<>(BacklogItem.comparator()); /** * Items stored in this batch. @@ -175,22 +175,17 @@ synchronized void setMaxSize(int maxSizeNew) { return; } - // Buffer exceeds the new limit. - // Move extra items to the backlog in LIFO order. + // Buffer exceeds the new limit. Move extra items to the backlog (LIFO). 
Iterator> extra = buffer.reversed() .entrySet().stream() .limit(buffer.size() - maxSize) .iterator(); while (extra.hasNext()) { - Map.Entry next = extra.next(); - backlog.add(next.getValue()); + Data data = extra.next().getValue(); + addBacklog(data); extra.remove(); } - // Reverse the backlog to restore the FIFO order. - // FIXME(dyma): this assumes setMaxSize is called on an empty backlog, - // but that's is simply not true. - Collections.reverse(backlog); } finally { checkInvariants(); } @@ -205,7 +200,7 @@ synchronized void setMaxSize(int maxSizeNew) { * will overflow the batch before it's removed from the queue, the simplest * and safest way to deal with it is to allow {@link Batch} to put * the overflowing item in the {@link #backlog}. The batch is considered - * full after that and will not accept any more items until it's cleared. + * full after that. * * @throws DataTooBigException If the data exceeds the maximum * possible batch size. @@ -236,7 +231,7 @@ synchronized void add(Data data) throws IllegalStateException, DataTooBigExcepti // has been reached to satisfy the invariant. // This doubles as a safeguard to ensure the caller cannot add any // more items to the batch before flushing it. - backlog.add(data); + addBacklog(data); sizeBytes += remainingBytes; assert isFull() : "batch must be full after an overflow"; } finally { @@ -255,6 +250,11 @@ private synchronized void addSafe(Data data) { sizeBytes += data.sizeBytes(); } + /** Add a data item to the {@link #backlog}. */ + private synchronized void addBacklog(Data data) { + backlog.add(new BacklogItem(data)); + } + /** * Clear this batch's internal buffer. * @@ -285,6 +285,7 @@ synchronized Collection clear() { // exceed maxSizeBytes. 
backlog.stream() .takeWhile(__ -> !isFull()) + .map(BacklogItem::data) .forEach(this::addSafe); return removed; @@ -293,6 +294,46 @@ synchronized Collection clear() { } } + private static record BacklogItem(Data data, Instant createdAt) { + public BacklogItem { + requireNonNull(data, "data is null"); + requireNonNull(createdAt, "createdAt is null"); + } + + /** + * This constructor sets {@link #createdAt} automatically. + * It is not important that this timestamp is different from + * the one in {@link TaskHandle}, as longs as the order is correct. + */ + public BacklogItem(Data data) { + this(data, Instant.now()); + } + + /** Comparator sorts BacklogItems by their creation time. */ + private static Comparator comparator() { + return new Comparator() { + + @Override + public int compare(BacklogItem a, BacklogItem b) { + if (a.equals(b)) { + return 0; + } + + int cmpInstant = a.createdAt.compareTo(b.createdAt); + boolean sameInstant = cmpInstant == 0; + if (sameInstant) { + // We cannot return 0 for two items with different + // contents, as it may result in data loss. + // If they were somehow created in the same instant, + // let them be sorted lexicographically. + return a.data.id().compareTo(b.data.id()); + } + return cmpInstant; + } + }; + } + } + /** Asserts the invariants of this class. 
*/ private synchronized void checkInvariants() { assert maxSize > 0 : "non-positive maxSize"; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java index 7fca98aca..4eb34d4a2 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java @@ -8,6 +8,7 @@ import java.util.Collection; import java.util.EnumSet; import java.util.List; +import java.util.Optional; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; import java.util.concurrent.CancellationException; @@ -18,6 +19,7 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; @@ -32,7 +34,8 @@ import io.grpc.stub.StreamObserver; import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; import io.weaviate.client6.v1.api.collections.WeaviateObject; -import io.weaviate.client6.v1.api.collections.batch.Event.RpcError; +import io.weaviate.client6.v1.api.collections.batch.Event.ClientError; +import io.weaviate.client6.v1.api.collections.batch.Event.StreamHangup; import io.weaviate.client6.v1.api.collections.data.BatchReference; import io.weaviate.client6.v1.api.collections.data.InsertManyRequest; import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; @@ -52,10 +55,32 @@ public final class BatchContext implements Closeable { private final int DEFAULT_BATCH_SIZE = 1000; private final int DEFAULT_QUEUE_SIZE = 100; + private final int MAX_RECONNECT_RETRIES = 5; private final CollectionDescriptor collectionDescriptor; private final CollectionHandleDefaults collectionHandleDefaults; + /** + * 
Internal execution service. Its lifecycle is bound to that of the + * BatchContext: it's started when the context is initialized + * and shut down on {@link #close}. + * + *

    + * In the event of abrupt stream termination ({@link Recv#onError} is called), + * the "recv" thread MAY shutdown this service in order to interrupt the "send" + * thread; the latter may be blocked on {@link Send#awaitCanSend} or + * {@link Send#awaitCanPrepareNext}. + */ + private final ExecutorService sendExec = Executors.newSingleThreadExecutor(); + + /** + * Scheduled thread pool for delayed tasks. + * + * @see Oom + * @see Reconnecting + */ + private final ScheduledExecutorService scheduledExec = Executors.newScheduledThreadPool(1); + /** Stream factory creates new streams. */ private final StreamFactory streamFactory; @@ -100,35 +125,15 @@ public final class BatchContext implements Closeable { /** stateChanged notifies threads about a state transition. */ private final Condition stateChanged = lock.newCondition(); - /** - * Internal execution service. It's lifecycle is bound to that of the - * BatchContext: it's started when the context is initialized - * and shutdown on {@link #close}. - * - *

    - * In the event of abrupt stream termination ({@link Recv#onError} is called), - * the "recv" thread MAY shutdown this service in order to interrupt the "send" - * thread; the latter may be blocked on {@link Send#awaitCanSend} or - * {@link Send#awaitCanPrepareNext}. - */ - private final ExecutorService sendExec = Executors.newSingleThreadExecutor(); - - /** - * Scheduled thread pool for OOM timer. - * - * @see Oom - */ - private final ScheduledExecutorService scheduledExec = Executors.newScheduledThreadPool(1); - - /** Service executor for polling {@link #workers} status before closing. */ - private final ExecutorService closeExec = Executors.newSingleThreadExecutor(); - /** * Client-side part of the current stream, created on {@link #start}. * Other threads MAY use stream but MUST NOT update this field on their own. */ private volatile StreamObserver messages; + /** Handle for the "send" thread. Use {@link Future#cancel} to interrupt it. */ + private volatile Future send; + /** * Latch reaches zero once both "send" (client side) and "recv" (server side) * parts of the stream have closed. After a {@link reconnect}, the latch is @@ -136,11 +141,23 @@ public final class BatchContext implements Closeable { */ private volatile CountDownLatch workers; - /** closed completes the stream. */ - private final CompletableFuture closed = new CompletableFuture<>(); + /** Lightway check to ensure users cannot send on a closed context. */ + private volatile boolean closed; - /** Thread which created the BatchContext. */ - private final Thread parent = Thread.currentThread(); + /** Closing state. 
*/ + private volatile Closing closing; + + void setClosing(Exception ex) { + if (closing == null) { + synchronized (Closing.class) { + if (closing == null) { + closing = new Closing(ex); + } + } + } + + setState(closing); + } BatchContext( StreamFactory streamFactory, @@ -153,6 +170,7 @@ public final class BatchContext implements Closeable { this.queue = new ArrayBlockingQueue<>(DEFAULT_QUEUE_SIZE); this.batch = new Batch(DEFAULT_BATCH_SIZE, maxSizeBytes); + setState(CLOSED); } /** Add {@link WeaviateObject} to the batch. */ @@ -174,15 +192,20 @@ public TaskHandle add(BatchReference reference) throws InterruptedException { void start() { workers = new CountDownLatch(2); - Recv recv = new Recv(); - messages = streamFactory.createStream(recv); + messages = streamFactory.createStream(new Recv()); + send = sendExec.submit(new Send()); - Send send = new Send(); - sendExec.execute(send); + messages.onNext(Message.start(collectionHandleDefaults.consistencyLevel())); + setState(AWAIT_STARTED); } - void reconnect() throws InterruptedException { + /** + * Reconnect waits for "send" and "recv" streams to exit + * and restarts the process with a new stream. + */ + void reconnect() throws InterruptedException, ExecutionException { workers.await(); + send.get(); start(); } @@ -198,40 +221,19 @@ public TaskHandle retry(TaskHandle taskHandle) throws InterruptedException { return add(taskHandle.retry()); } - /** - * Interrupt all subprocesses, notify the server, de-allocate resources, - * and abort the stream. - * - * @apiNote This is not a normal shutdown process. It is an abrupt termination - * triggered by an exception. 
- */ - private void abort(Throwable t) { - messages.onError(Status.INTERNAL.withCause(t).asRuntimeException()); - closed.completeExceptionally(t); - parent.interrupt(); - sendExec.shutdown(); - } - @Override public void close() throws IOException { - closeExec.execute(() -> { - try { - queue.put(TaskHandle.POISON); - workers.await(); - closed.complete(null); - } catch (Exception e) { - closed.completeExceptionally(e); - } - }); + setClosing(null); try { - closed.get(); + closing.await(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } catch (ExecutionException e) { throw new IOException(e.getCause()); } finally { shutdownExecutors(); + setState(CLOSED); } } @@ -249,7 +251,7 @@ private void shutdownExecutors() { pending = scheduledExec.shutdownNow(); assertEmpty.accept("oom", pending); - pending = closeExec.shutdownNow(); + pending = closing.shutdownNow(); assertEmpty.accept("close", pending); } @@ -269,7 +271,7 @@ void setState(State nextState) { } /** onEvent delegates event handling to {@link #state}. */ - void onEvent(Event event) { + private void onEvent(Event event) { lock.lock(); try { state.onEvent(event); @@ -279,7 +281,7 @@ void onEvent(Event event) { } private TaskHandle add(final TaskHandle taskHandle) throws InterruptedException { - if (closed.isDone()) { + if (closed) { throw new IllegalStateException("BatchContext is closed"); } @@ -309,6 +311,8 @@ public void run() { */ private void trySend() { try { + awaitCanPrepareNext(); + while (!Thread.currentThread().isInterrupted()) { if (batch.isFull()) { send(); @@ -328,10 +332,12 @@ private void trySend() { assert existing == null : "duplicate tasks in progress, id=" + existing.id(); } } catch (InterruptedException ignored) { - // Allow this method to exit normally to close our end of the stream. + // This thread is only interrupted in the RECONNECTING state, not by + // the user's code. Allow this method to exit normally to close our + // end of the stream. 
+ Thread.currentThread().interrupt(); } catch (Exception e) { - onEvent(new Event.RpcError(e)); - messages.onError(e); + onEvent(new Event.ClientError(e)); return; } @@ -426,56 +432,32 @@ public void onNext(Event event) { */ @Override public void onCompleted() { - workers.countDown(); - - // TODO(dyma): I'm not sure if there isn't a race here. - // Can wip be empty but there still be more work to do? - // wip can be cleared after Results message for the last - // batch arrives. This can only happen in ACTIVE stage, - // because OOM and SERVER_SHUTTING_DOWN imply the last batch - // was not acked. - // Should we double-check that queue.isEmpty()? Can it still - // contain a POISON at this time? If it's a normal shutdown - // after #close, then we've already taken the POISON and sent Stop. - // If it happends after SERVER_SHUTTING_DOWN, same. - if (wip.isEmpty()) { - return; - } - try { - reconnect(); - } catch (InterruptedException ignored) { - // Recv is running on a thread from gRPC's internal thread pool, - // so, while onEvent allows InterruptedException to stay responsive, - // in practice this thread will only be interrupted by the thread pool, - // which already knows it's being shut down. + onEvent(Event.EOF); + } finally { + workers.countDown(); } } /** An exception occurred either on our end or in the channel internals. */ @Override public void onError(Throwable t) { - // TODO(dyma): treat this as a StreamHangup - // After re-connecting, we need to re-submit all WIP tasks, which aren't - // in the queue. Maybe this means we should only add to WIP once we add it - // to the batch? Need to muse on that. 
try { - onEvent(Event.RpcError.fromThrowable(t)); + onEvent(Event.StreamHangup.fromThrowable(t)); } finally { workers.countDown(); } } } - final State CLOSED = new BaseState("CLOSED"); final State AWAIT_STARTED = new BaseState("AWAIT_STARTED", BaseState.Action.PREPARE_NEXT) { @Override public void onEvent(Event event) { if (requireNonNull(event, "event is null") == Event.STARTED) { setState(ACTIVE); - return; + } else { + super.onEvent(event); } - super.onEvent(event); } }; final State ACTIVE = new BaseState("ACTIVE", BaseState.Action.PREPARE_NEXT, BaseState.Action.SEND); @@ -496,9 +478,8 @@ public void onEvent(Event event) { } }); setState(ACTIVE); - } else if (event == Event.OOM) { - int delaySeconds = 300; - setState(new Oom(delaySeconds)); + } else if (event instanceof Event.Oom oom) { + setState(new Oom(oom.delaySeconds())); } else { super.onEvent(event); } @@ -532,24 +513,64 @@ public boolean canPrepareNext() { return permitted.contains(Action.PREPARE_NEXT); } + /** + * Handle events which may arrive at any moment without violating the protocol. + * + *

      + *
    • {@link Event.Results} -- update tasks in {@link wip} and remove them. + *
    • {@link Event.Backoff} -- adjust batch size. + *
    • {@link Event#SHUTTING_DOWN} -- transition into + * {@link ServerShuttingDown}. + *
    • {@link Event.StreamHangup} -- transition into {@link Reconnecting} state. + *
    • {@link Event.ClientError} -- transition into {@link Closing} state with + * exception. + *
    + * + * @throws ProtocolViolationException If event cannot be handled in this state. + */ @Override public void onEvent(Event event) { requireNonNull(event, "event is null"); if (event instanceof Event.Results results) { - results.successful().forEach(id -> wip.get(id).setSuccess()); - results.errors().forEach((id, error) -> wip.get(id).setError(error)); + onResults(results); } else if (event instanceof Event.Backoff backoff) { - batch.setMaxSize(backoff.maxSize()); + onBackoff(backoff); } else if (event == Event.SHUTTING_DOWN) { - setState(new ServerShuttingDown(this)); - } else if (event instanceof Event.RpcError error) { - setState(new StreamAborted(error.exception())); + onShuttingDown(); + } else if (event instanceof Event.StreamHangup || event == Event.EOF) { + onStreamClosed(event); + } else if (event instanceof Event.ClientError error) { + onClientError(error); } else { throw ProtocolViolationException.illegalStateTransition(this, event); } } + private final void onResults(Event.Results results) { + results.successful().forEach(id -> wip.remove(id).setSuccess()); + results.errors().forEach((id, error) -> wip.remove(id).setError(error)); + } + + private final void onBackoff(Event.Backoff backoff) { + batch.setMaxSize(backoff.maxSize()); + } + + private final void onShuttingDown() { + setState(new ServerShuttingDown(this)); + } + + private final void onStreamClosed(Event event) { + if (event instanceof Event.StreamHangup hangup) { + // TODO(dyma): log error? 
+ } + setState(new Reconnecting(MAX_RECONNECT_RETRIES)); + } + + private final void onClientError(Event.ClientError error) { + setClosing(error.exception()); + } + @Override public String toString() { return name; @@ -580,13 +601,15 @@ private void initiateShutdown() { return; } onEvent(Event.SHUTTING_DOWN); - onEvent(Event.SHUTDOWN); + onEvent(Event.EOF); } @Override public void onEvent(Event event) { requireNonNull(event, "event"); - if (event == Event.SHUTTING_DOWN || event instanceof RpcError) { + if (event == Event.SHUTTING_DOWN || + event instanceof StreamHangup || + event instanceof ClientError) { shutdown.cancel(true); try { shutdown.get(); @@ -630,56 +653,135 @@ public boolean canSend() { @Override public void onEnter(State prev) { - // FIXME(dyma): we should have an orderly shutdown with a poison pill here! - // After the reconnect we would want to continue the execution like before. - // Remember, "reconnect" after a normal shutdown should only affect the - // `messages` stream and re-create Recv. The rest of the state is preserved. - // - // If we #shutdown the service it won't be able to execute - // the next Send routine. - sendExec.shutdown(); - } - - // TODO(dyma): if we agree to retire Shutdown, then ServerShuttingDown - // should not override onEvent and let it fallthough to - // ProtocolViolationException on any event. + send.cancel(true); + } + } + + private final class Reconnecting extends BaseState { + private final int maxRetries; + private int retries = 0; + + private Reconnecting(int maxRetries) { + super("RECONNECTING", Action.PREPARE_NEXT); + this.maxRetries = maxRetries; + } + + @Override + public void onEnter(State prev) { + send.cancel(true); + + if (prev.getClass() != ServerShuttingDown.class) { + // This is NOT an orderly shutdown, we're reconnecting after a stream hangup. + // Assume all WIP items have been lost and re-submit everything. 
+ // All items in the batch are contained in WIP, so it is safe to discard the + // batch entirely and re-populate from WIP. + while (!batch.isEmpty()) { + batch.clear(); + } + + // Unlike during normal operation, we will not stop when batch.isFull(). + // Batch#add guarantees that data will not be discarded in the event of + // an overflow -- all extra items are placed into the backlog, which is + // unbounded. + wip.values().forEach(task -> batch.add(task.data())); + } + + reconnectNow(); + } + @Override public void onEvent(Event event) { - if (requireNonNull(event, "event is null") == Event.SHUTDOWN) { - return; + assert retries <= maxRetries : "maxRetries exceeded"; + + if (event == Event.STARTED) { + setState(ACTIVE); + } else if (event instanceof Event.StreamHangup) { + if (retries == maxRetries) { + onEvent(new ClientError(new IOException("Server unavailable"))); + } else { + reconnectAfter(1 * 2 ^ retries); + } } - super.onEvent(event); + + assert retries <= maxRetries : "maxRetries exceeded"; + } + + private void reconnectNow() { + reconnectAfter(0); + } + + private void reconnectAfter(long delaySeconds) { + retries++; + + scheduledExec.schedule(() -> { + try { + reconnect(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } catch (ExecutionException e) { + onEvent(new Event.ClientError(e)); + } + + }, delaySeconds, TimeUnit.SECONDS); } } - /** - * StreamAborted means the RPC is "dead": the {@link messages} stream is closed - * and using it will result in an {@link IllegalStateException}. - */ - private final class StreamAborted extends BaseState { - private final Throwable t; + private final class Closing extends BaseState { + /** Service executor for polling {@link #workers} status before closing. */ + private final ExecutorService exec = Executors.newSingleThreadExecutor(); + + /** closed completes the stream. 
*/ + private final CompletableFuture future = new CompletableFuture<>(); - protected StreamAborted(Throwable t) { - super("STREAM_ABORTED"); - this.t = t; + private final Optional ex; + + private Closing(Exception ex) { + super("CLOSING"); + this.ex = Optional.ofNullable(ex); } @Override public void onEnter(State prev) { - // FIXME(dyma): we need to differentiate between "ABORT" the stream - // and "RECONNECT" on server hangup. - // ABORT should be for sender errors, where we cannot fix something - // by reconnecting. This should be called InternalError state. - // - // RECONNECT deals with StreamHangup event (from Recv#onError). - // We should just reconnect and re-submit all WIP items. - abort(t); + exec.execute(() -> { + try { + stopSend(); + workers.await(); + future.complete(null); + } catch (Exception e) { + future.completeExceptionally(e); + } + }); } @Override public void onEvent(Event event) { - // StreamAborted cannot transition into another state. It is terminal -- - // BatchContext MUST terminate its subprocesses and close exceptionally. 
+ if (event != Event.EOF) { + super.onEvent(event); // falthrough + } + } + + private void stopSend() throws InterruptedException { + if (ex.isEmpty()) { + queue.put(TaskHandle.POISON); + } else { + messages.onError(Status.INTERNAL.withCause(ex.get()).asRuntimeException()); + send.cancel(true); + } + } + + void await() throws InterruptedException, ExecutionException { + future.get(); + } + + List shutdownNow() { + return exec.shutdownNow(); } } + + final State CLOSED = new BaseState("CLOSED") { + @Override + public void onEnter(State prev) { + closed = true; + } + }; } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java index 03555907c..e5c2dd005 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java @@ -9,18 +9,20 @@ import io.grpc.Status; import io.weaviate.client6.v1.api.collections.batch.Event.Acks; import io.weaviate.client6.v1.api.collections.batch.Event.Backoff; +import io.weaviate.client6.v1.api.collections.batch.Event.ClientError; +import io.weaviate.client6.v1.api.collections.batch.Event.Oom; import io.weaviate.client6.v1.api.collections.batch.Event.Results; -import io.weaviate.client6.v1.api.collections.batch.Event.RpcError; import io.weaviate.client6.v1.api.collections.batch.Event.Started; +import io.weaviate.client6.v1.api.collections.batch.Event.StreamHangup; import io.weaviate.client6.v1.api.collections.batch.Event.TerminationEvent; sealed interface Event - permits Started, Acks, Results, Backoff, TerminationEvent, RpcError { + permits Started, Acks, Results, Backoff, Oom, TerminationEvent, StreamHangup, ClientError { final static Event STARTED = new Started(); final static Event OOM = TerminationEvent.OOM; final static Event SHUTTING_DOWN = TerminationEvent.SHUTTING_DOWN; - final static Event SHUTDOWN = TerminationEvent.SHUTDOWN; + final static 
Event EOF = TerminationEvent.EOF; /** */ record Started() implements Event { @@ -88,23 +90,25 @@ public String toString() { } } - enum TerminationEvent implements Event { - /** - * Out-Of-Memory. - * - *

    - * Items sent in the previous request cannot be accepted, - * as inserting them may exhaust server's available disk space. - * On receiving this message, the client MUST stop producing - * messages immediately and await {@link #SHUTTING_DOWN} event. - * - *

    - * {@link #OOM} is the sibling of {@link Acks} with the opposite effect. - * The protocol guarantees that the server will respond with either of - * the two, but never both. - */ - OOM, + /** + * Out-Of-Memory. + * + *

    + * Items sent in the previous request cannot be accepted, + * as inserting them may exhaust server's available disk space. + * On receiving this message, the client MUST stop producing + * messages immediately and await {@link #SHUTTING_DOWN} event. + * + *

    + * Oom is the sibling of {@link Acks} with the opposite effect. + * The protocol guarantees that the server will respond with either of + * the two, but never both. + */ + record Oom(int delaySeconds) implements Event { + } + /** Events that are part of the server's graceful shutdown strategy. */ + enum TerminationEvent implements Event { /** * Server shutdown in progress. * @@ -113,27 +117,45 @@ enum TerminationEvent implements Event { * scale-up event (if it previously reported {@link #OOM}) or * some other external event. * On receiving this message, the client MUST stop producing - * messages immediately and close it's side of the stream. + * messages immediately, close its side of the stream, and + * continue reading server's messages until {@link #EOF}. */ SHUTTING_DOWN, /** - * Server is shutdown. * *

    - * The server has finished the shutdown process and will not - * receive any messages. On receiving this message, the client - * MUST continue reading messages in the stream until the server - * closes it on its end, then re-connect to another instance + * The server will not receive any messages. If the client + * has more data to send, it SHOULD re-connect to another instance * by re-opening the stream and continue processing the batch. + * If the client has previously sent {@link Message#STOP}, it can + * safely exit. */ - SHUTDOWN; + EOF; } - record RpcError(Exception exception) implements Event { - static RpcError fromThrowable(Throwable t) { + /** + * StreamHangup means the RPC is "dead": the stream is closed + * and using it will result in an {@link IllegalStateException}. + */ + record StreamHangup(Exception exception) implements Event { + static StreamHangup fromThrowable(Throwable t) { Status status = Status.fromThrowable(t); - return new RpcError(status.asException()); + return new StreamHangup(status.asException()); } } + + /** + * ClientError means a client-side exception has happened, + * and is meant primarily for the "send" thread to propagate + * any exception it might catch. + * + *

    + * This MUST be treated as an irrecoverable condition, because + * it is likely caused by an internal issue (an NPE) or a bad + * input ({@link DataTooBigException}). + */ + record ClientError(Exception exception) implements Event { + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java index ea93afffe..ab8ad1149 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java @@ -168,13 +168,8 @@ public int timesRetried() { return retries; } - /** Task creation timestamp. Retried tasks have different timestamps. */ - public Instant createdAt() { - return createdAt; - } - @Override public String toString() { - return "TaskHandle".formatted(id(), timesRetried(), createdAt()); + return "TaskHandle".formatted(id(), retries, createdAt); } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java index 530ed31e8..a67780c7e 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java @@ -96,9 +96,10 @@ public void onNext(BatchStreamReply reply) { case SHUTTING_DOWN: event = Event.SHUTTING_DOWN; case SHUTDOWN: - event = Event.SHUTDOWN; + event = Event.EOF; case OUT_OF_MEMORY: - event = Event.OOM; + // TODO(dyma): read this value from the message + event = new Event.Oom(300); case BACKOFF: event = new Event.Backoff(reply.getBackoff().getBatchSize()); case ACKS: From 3bd24249e8860f7b36d89aff02ef34d24cba4360 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Tue, 17 Feb 2026 14:58:56 +0100 Subject: [PATCH 09/22] chore(batch): fix type Messeger -> Messenger --- 
.../v1/api/collections/batch/TranslatingStreamFactory.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java index a67780c7e..5b60600c0 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java @@ -31,7 +31,7 @@ class TranslatingStreamFactory implements StreamFactory { @Override public StreamObserver createStream(StreamObserver recv) { - return new Messeger(protoFactory.createStream(new Eventer(recv))); + return new Messenger(protoFactory.createStream(new Eventer(recv))); } /** @@ -64,8 +64,8 @@ public void onError(Throwable t) { * * @see Message */ - private final class Messeger extends DelegatingStreamObserver { - private Messeger(StreamObserver delegate) { + private final class Messenger extends DelegatingStreamObserver { + private Messenger(StreamObserver delegate) { super(delegate); } From 825bd6542fb909907c240ace85489efc90b5a41c Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 18 Feb 2026 14:26:08 +0100 Subject: [PATCH 10/22] fix(batch): perform OOM reconnect sequence via BaseState, not BatchContext#onEvent --- .../api/collections/batch/BatchContext.java | 80 ++++++++++++++++--- 1 file changed, 70 insertions(+), 10 deletions(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java index 4eb34d4a2..6bdf26e1b 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java @@ -147,16 +147,19 @@ public final class BatchContext implements Closeable { /** Closing state. 
*/ private volatile Closing closing; + /** + * setClosing trasitions BatchContext to {@link Closing} state exactly once. + * Once this method returns, the caller can call {@code closing.await()}. + */ void setClosing(Exception ex) { if (closing == null) { synchronized (Closing.class) { if (closing == null) { closing = new Closing(ex); + setState(closing); } } } - - setState(closing); } BatchContext( @@ -221,9 +224,18 @@ public TaskHandle retry(TaskHandle taskHandle) throws InterruptedException { return add(taskHandle.retry()); } + /** + * Close attempts to drain the queue and send all remaining items. + * Calling any of BatchContext's public methods afterwards will + * result in an {@link IllegalStateException}. + * + * @throws IOException Propagates an exception + * if one has occurred in the meantime. + */ @Override public void close() throws IOException { setClosing(null); + assert closing != null : "closing state not set"; try { closing.await(); @@ -270,7 +282,16 @@ void setState(State nextState) { } } - /** onEvent delegates event handling to {@link #state}. */ + /** + * onEvent delegates event handling to {@link #state}. + * + *

    + * Be mindful that most of the time this callback will run in a hot path + * on a gRPC thread. {@link State} implementations SHOULD offload any + * blocking operations to one of the provided executors. + * + * @see #scheduledExec + */ private void onEvent(Event event) { lock.lock(); try { @@ -487,16 +508,29 @@ public void onEvent(Event event) { }; private class BaseState implements State { + /** State's display name for logging. */ private final String name; + /** Actions permitted in this state. */ private final EnumSet permitted; enum Action { - PREPARE_NEXT, SEND; + /** + * Thy system is allowed to accept new items from the user + * and populate the next batch. + */ + PREPARE_NEXT, + + /** The system is allowed to send the next batch once it's ready. */ + SEND; } - protected BaseState(String name, Action... actions) { + /** + * @param name Display name. + * @param permitted Actions permitted in this state. + */ + protected BaseState(String name, Action... permitted) { this.name = name; - this.permitted = EnumSet.copyOf(Arrays.asList(requireNonNull(actions, "actions is null"))); + this.permitted = EnumSet.copyOf(Arrays.asList(requireNonNull(permitted, "actions is null"))); } @Override @@ -517,7 +551,7 @@ public boolean canPrepareNext() { * Handle events which may arrive at any moment without violating the protocol. * *

      - *
    • {@link Event.Results} -- update tasks in {@link wip} and remove them. + *
    • {@link Event.Results} -- update tasks in {@link #wip} and remove them. *
    • {@link Event.Backoff} -- adjust batch size. *
    • {@link Event#SHUTTING_DOWN} -- transition into * {@link ServerShuttingDown}. @@ -597,11 +631,17 @@ public void onEnter(State prev) { /** Imitate server shutdown sequence. */ private void initiateShutdown() { + // We cannot route event handling via normal BatchContext#onEvent, because + // it delegates to the current state, which is Oom. If Oom#onEvent were to + // receive an Event.SHUTTING_DOWN, it would cancel this execution of this + // very sequence. Instead, we delegate to our parent BaseState which normally + // handles these events. if (Thread.currentThread().isInterrupted()) { - return; + super.onEvent(Event.SHUTTING_DOWN); + } + if (Thread.currentThread().isInterrupted()) { + super.onEvent(Event.EOF); } - onEvent(Event.SHUTTING_DOWN); - onEvent(Event.EOF); } @Override @@ -657,6 +697,16 @@ public void onEnter(State prev) { } } + /** + * Reconnecting state is entererd either by the server finishing a shutdown + * and closing it's end of the stream or an unexpected stream hangup. + * + *

      + * + * + * @see Recv#onCompleted graceful server shutdown + * @see Recv#onError stream hangup + */ private final class Reconnecting extends BaseState { private final int maxRetries; private int retries = 0; @@ -706,10 +756,20 @@ public void onEvent(Event event) { assert retries <= maxRetries : "maxRetries exceeded"; } + /** Reconnect with no delay. */ private void reconnectNow() { reconnectAfter(0); } + /** + * Schedule a task to {@link #reconnect} after a delay. + * + * @param delaySeconds Delay in seconds. + * + * @apiNote The task is scheduled on {@link #scheduledExec} even if + * {@code delaySeconds == 0} to avoid blocking gRPC worker thread, + * where the {@link BatchContext#onEvent} callback runs. + */ private void reconnectAfter(long delaySeconds) { retries++; From d60ed54d6d4ae4e4be4da7bd72ce8f6ae75a9160 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 18 Feb 2026 15:18:32 +0100 Subject: [PATCH 11/22] feat(batch): reconnect to GCP every 160 seconds --- .../io/weaviate/client6/v1/api/Config.java | 14 ++------ .../api/collections/batch/BatchContext.java | 33 ++++++++++++++++--- .../v1/api/collections/batch/Event.java | 11 +++++-- .../batch/WeaviateBatchClient.java | 16 ++++++++- .../client6/v1/internal/TransportOptions.java | 16 +++++++++ .../internal/grpc/DefaultGrpcTransport.java | 5 +++ .../v1/internal/grpc/GrpcTransport.java | 2 ++ .../testutil/transport/MockGrpcTransport.java | 6 ++++ 8 files changed, 83 insertions(+), 20 deletions(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/Config.java b/src/main/java/io/weaviate/client6/v1/api/Config.java index 116454893..33baabb63 100644 --- a/src/main/java/io/weaviate/client6/v1/api/Config.java +++ b/src/main/java/io/weaviate/client6/v1/api/Config.java @@ -11,6 +11,7 @@ import io.weaviate.client6.v1.internal.ObjectBuilder; import io.weaviate.client6.v1.internal.Timeout; import io.weaviate.client6.v1.internal.TokenProvider; +import io.weaviate.client6.v1.internal.TransportOptions; import 
io.weaviate.client6.v1.internal.grpc.GrpcChannelOptions; import io.weaviate.client6.v1.internal.rest.RestTransportOptions; @@ -181,17 +182,6 @@ public SelfT timeout(int initSeconds, int querySeconds, int insertSeconds) { private static final String HEADER_X_WEAVIATE_CLUSTER_URL = "X-Weaviate-Cluster-URL"; private static final String HEADER_X_WEAVIATE_CLIENT = "X-Weaviate-Client"; - /** - * isWeaviateDomain returns true if the host matches weaviate.io, - * semi.technology, or weaviate.cloud domain. - */ - private static boolean isWeaviateDomain(String host) { - var lower = host.toLowerCase(); - return lower.contains("weaviate.io") || - lower.contains("semi.technology") || - lower.contains("weaviate.cloud"); - } - private static final String VERSION = "weaviate-client-java/" + ((!BuildInfo.TAGS.isBlank() && BuildInfo.TAGS != "unknown") ? BuildInfo.TAGS : (BuildInfo.BRANCH + "-" + BuildInfo.COMMIT_ID_ABBREV)); @@ -200,7 +190,7 @@ private static boolean isWeaviateDomain(String host) { public Config build() { // For clusters hosted on Weaviate Cloud, Weaviate Embedding Service // will be available under the same domain. 
- if (isWeaviateDomain(httpHost) && authentication != null) { + if (TransportOptions.isWeaviateDomain(httpHost) && authentication != null) { setHeader(HEADER_X_WEAVIATE_CLUSTER_URL, "https://" + httpHost + ":" + httpPort); } setHeader(HEADER_X_WEAVIATE_CLIENT, VERSION); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java index 6bdf26e1b..6d8e1cdf5 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java @@ -174,6 +174,7 @@ void setClosing(Exception ex) { this.queue = new ArrayBlockingQueue<>(DEFAULT_QUEUE_SIZE); this.batch = new Batch(DEFAULT_BATCH_SIZE, maxSizeBytes); setState(CLOSED); + } /** Add {@link WeaviateObject} to the batch. */ @@ -701,9 +702,6 @@ public void onEnter(State prev) { * Reconnecting state is entererd either by the server finishing a shutdown * and closing it's end of the stream or an unexpected stream hangup. * - *

      - * - * * @see Recv#onCompleted graceful server shutdown * @see Recv#onError stream hangup */ @@ -781,7 +779,6 @@ private void reconnectAfter(long delaySeconds) { } catch (ExecutionException e) { onEvent(new Event.ClientError(e)); } - }, delaySeconds, TimeUnit.SECONDS); } } @@ -844,4 +841,32 @@ public void onEnter(State prev) { closed = true; } }; + + // -------------------------------------------------------------------------- + + private final ScheduledExecutorService reconnectExec = Executors.newScheduledThreadPool(1); + + void scheduleReconnect(int reconnectIntervalSeconds) { + reconnectExec.scheduleWithFixedDelay(() -> { + if (Thread.currentThread().isInterrupted()) { + onEvent(Event.SHUTTING_DOWN); + } + if (Thread.currentThread().isInterrupted()) { + onEvent(Event.EOF); + } + + // We want to count down from the moment we re-opened the stream, + // not from the moment we initialited the sequence. + lock.lock(); + try { + while (state != ACTIVE) { + stateChanged.await(); + } + } catch (InterruptedException ignored) { + // Let the process exit normally. 
+ } finally { + lock.unlock(); + } + }, reconnectIntervalSeconds, reconnectIntervalSeconds, TimeUnit.SECONDS); + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java index e5c2dd005..ed7d318b0 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java @@ -5,6 +5,7 @@ import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.OptionalInt; import io.grpc.Status; import io.weaviate.client6.v1.api.collections.batch.Event.Acks; @@ -20,12 +21,16 @@ sealed interface Event permits Started, Acks, Results, Backoff, Oom, TerminationEvent, StreamHangup, ClientError { final static Event STARTED = new Started(); - final static Event OOM = TerminationEvent.OOM; final static Event SHUTTING_DOWN = TerminationEvent.SHUTTING_DOWN; final static Event EOF = TerminationEvent.EOF; - /** */ - record Started() implements Event { + /** + * The server has acknowledged our Start message and is ready to receive data. + * + * @param reconnectAfterSeconds Delay in seconds after which + * the stream should be renewed. 
+ */ + record Started(OptionalInt reconnectAfterSeconds) implements Event { } /** diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/WeaviateBatchClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/WeaviateBatchClient.java index 551d3adcb..32953f74c 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/WeaviateBatchClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/WeaviateBatchClient.java @@ -5,6 +5,7 @@ import java.util.OptionalInt; import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.internal.TransportOptions; import io.weaviate.client6.v1.internal.grpc.GrpcTransport; import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; @@ -34,11 +35,24 @@ public BatchContext start() { if (maxSizeBytes.isEmpty()) { throw new IllegalStateException("Server must have grpcMaxMessageSize configured to use server-side batching"); } + StreamFactory streamFactory = new TranslatingStreamFactory(grpcTransport::createStream); - return new BatchContext<>( + BatchContext context = new BatchContext<>( streamFactory, maxSizeBytes.getAsInt(), collectionDescriptor, defaults); + + if (isWeaviateCloudOnGoogleCloud(grpcTransport.host())) { + context.scheduleReconnect(GCP_RECONNECT_INTERVAL_SECONDS); + } + + return context; + } + + private static final int GCP_RECONNECT_INTERVAL_SECONDS = 160; + + private static boolean isWeaviateCloudOnGoogleCloud(String host) { + return TransportOptions.isWeaviateDomain(host) && TransportOptions.isGoogleCloudDomain(host); } } diff --git a/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java b/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java index 897bb28cd..06c0b6c15 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java +++ b/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java @@ -57,4 +57,20 @@ public H headers() { public TrustManagerFactory 
trustManagerFactory() { return this.trustManagerFactory; } + + /** + * isWeaviateDomain returns true if the host matches weaviate.io, + * semi.technology, or weaviate.cloud domain. + */ + public static boolean isWeaviateDomain(String host) { + var lower = host.toLowerCase(); + return lower.contains("weaviate.io") || + lower.contains("semi.technology") || + lower.contains("weaviate.cloud"); + } + + public static boolean isGoogleCloudDomain(String host) { + var lower = host.toLowerCase(); + return lower.contains("gcp"); + } } diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java index 55f428bfb..385808ffc 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java +++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java @@ -177,4 +177,9 @@ public void close() throws Exception { callCredentials.close(); } } + + @Override + public String host() { + return transportOptions.host(); + } } diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcTransport.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcTransport.java index 6b4b6a804..76bb7a3be 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcTransport.java +++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcTransport.java @@ -23,6 +23,8 @@ CompletableFuture performRequ StreamObserver createStream( StreamObserver recv); + String host(); + /** * Maximum inbound/outbound message size supported by the underlying channel. 
*/ diff --git a/src/test/java/io/weaviate/testutil/transport/MockGrpcTransport.java b/src/test/java/io/weaviate/testutil/transport/MockGrpcTransport.java index 5778fbc6d..98af3d227 100644 --- a/src/test/java/io/weaviate/testutil/transport/MockGrpcTransport.java +++ b/src/test/java/io/weaviate/testutil/transport/MockGrpcTransport.java @@ -17,6 +17,7 @@ import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch.BatchStreamRequest; public class MockGrpcTransport implements GrpcTransport { + private final String host = "example.com"; @FunctionalInterface public interface AssertFunction { @@ -73,4 +74,9 @@ public OptionalInt maxMessageSizeBytes() { // TODO(dyma): implement for tests throw new UnsupportedOperationException("Unimplemented method 'maxMessageSizeBytes'"); } + + @Override + public String host() { + return host; + } } From c0c58989e01737a2afe1a468e0db65fd4f5749ad Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Mon, 16 Feb 2026 18:41:37 +0100 Subject: [PATCH 12/22] fix(it): automatically pick up the latest container version --- .../java/io/weaviate/containers/Weaviate.java | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/it/java/io/weaviate/containers/Weaviate.java b/src/it/java/io/weaviate/containers/Weaviate.java index 0f463e96f..c12bf9f6e 100644 --- a/src/it/java/io/weaviate/containers/Weaviate.java +++ b/src/it/java/io/weaviate/containers/Weaviate.java @@ -26,7 +26,7 @@ public class Weaviate extends WeaviateContainer { public static final String DOCKER_IMAGE = "semitechnologies/weaviate"; - public static final String LATEST_VERSION = Version.V135.semver.toString(); + public static final String LATEST_VERSION = Version.latest().semver.toString(); public static final String VERSION; static { @@ -41,7 +41,8 @@ public enum Version { V132(1, 32, 24), V133(1, 33, 11), V134(1, 34, 7), - V135(1, 35, 2); + V135(1, 35, 2), + V136(1, 36, "0-rc.0"); public final SemanticVersion semver; @@ -49,9 +50,21 @@ private 
Version(int major, int minor, int patch) { this.semver = new SemanticVersion(major, minor, patch); } + private Version(int major, int minor, String patch) { + this.semver = new SemanticVersion(major, minor, patch); + } + public void orSkip() { ConcurrentTest.requireAtLeast(this); } + + public static Version latest() { + Version[] versions = Version.class.getEnumConstants(); + if (versions == null) { + throw new IllegalStateException("No versions are defined"); + } + return versions[versions.length - 1]; + } } /** From 55ff367529d6571b730ef7b7ed2ddc1f6f074cf0 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Mon, 16 Feb 2026 20:15:07 +0100 Subject: [PATCH 13/22] ci(test): add v1.36.0-rc.0 to the testing matrix --- .github/workflows/test.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index cdcf934fc..50fe6d1b1 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -78,7 +78,8 @@ jobs: strategy: fail-fast: false matrix: - WEAVIATE_VERSION: ["1.32.24", "1.33.11", "1.34.7", "1.35.2"] + WEAVIATE_VERSION: + ["1.32.24", "1.33.11", "1.34.7", "1.35.2", "1.36.0-rc.0"] steps: - uses: actions/checkout@v4 From 64d49987facf428b6b9ffda9fd55d58eb0b0576b Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 18 Feb 2026 22:41:39 +0100 Subject: [PATCH 14/22] chore(pom.xml): update dependencies --- pom.xml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pom.xml b/pom.xml index 7a98a1a84..8414bc832 100644 --- a/pom.xml +++ b/pom.xml @@ -58,20 +58,20 @@ 3.20.0 4.13.2 2.0.3 - 3.27.6 + 3.27.7 1.0.4 5.21.0 2.0.17 1.5.18 5.14.0 2.21 - 11.31.1 + 11.33 5.15.0 - 4.33.4 - 4.33.4 - 1.78.0 - 1.78.0 - 1.78.0 + 4.33.5 + 4.33.5 + 1.79.0 + 1.79.0 + 1.79.0 6.0.53 From f29a274ad356d79f176a64f8254b41644f31e777 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 18 Feb 2026 22:43:01 +0100 Subject: [PATCH 15/22] test(batch): add the '10_000 objects' integration 
test --- .../io/weaviate/integration/BatchITest.java | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 src/it/java/io/weaviate/integration/BatchITest.java diff --git a/src/it/java/io/weaviate/integration/BatchITest.java b/src/it/java/io/weaviate/integration/BatchITest.java new file mode 100644 index 000000000..0e57c914c --- /dev/null +++ b/src/it/java/io/weaviate/integration/BatchITest.java @@ -0,0 +1,48 @@ +package io.weaviate.integration; + +import java.io.IOException; +import java.util.Map; +import java.util.UUID; + +import org.assertj.core.api.Assertions; +import org.junit.BeforeClass; +import org.junit.Test; + +import io.weaviate.ConcurrentTest; +import io.weaviate.client6.v1.api.WeaviateClient; +import io.weaviate.client6.v1.api.collections.Property; +import io.weaviate.client6.v1.api.collections.WeaviateObject; +import io.weaviate.containers.Container; +import io.weaviate.containers.Weaviate; + +public class BatchITest extends ConcurrentTest { + private static final WeaviateClient client = Container.WEAVIATE.getClient(); + + @BeforeClass + public static void __() { + Weaviate.Version.V136.orSkip(); + } + + @Test + public void test() throws IOException { + var nsThings = ns("Things"); + + var things = client.collections.create( + nsThings, + c -> c.properties(Property.text("letter"))); + + // Act + try (var batch = things.batch.start()) { + for (int i = 0; i < 10_000; i++) { + String uuid = UUID.randomUUID().toString(); + batch.add(WeaviateObject.of(builder -> builder + .uuid(uuid) + .properties(Map.of("letter", uuid.substring(0, 1))))); + } + } catch (InterruptedException e) { + } + + // Assert + Assertions.assertThat(things.size()).isEqualTo(10_000); + } +} From c5ea8bdf0ceec57456a7d21759f31a4fa64fe0e4 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 18 Feb 2026 22:43:32 +0100 Subject: [PATCH 16/22] fix(batch): remove redundat parameter from Event.Started --- .../io/weaviate/client6/v1/api/collections/batch/Event.java | 6 
+----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java index ed7d318b0..ba3f6b208 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java @@ -5,7 +5,6 @@ import java.util.Collection; import java.util.List; import java.util.Map; -import java.util.OptionalInt; import io.grpc.Status; import io.weaviate.client6.v1.api.collections.batch.Event.Acks; @@ -26,11 +25,8 @@ sealed interface Event /** * The server has acknowledged our Start message and is ready to receive data. - * - * @param reconnectAfterSeconds Delay in seconds after which - * the stream should be renewed. */ - record Started(OptionalInt reconnectAfterSeconds) implements Event { + record Started() implements Event { } /** From 1a4645dc0e94a33caec2e1f81e87378bee5fd8a1 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 18 Feb 2026 22:47:27 +0100 Subject: [PATCH 17/22] fix(batch): use API compatible w/ JDK 17 LinkedHashMap::reversed was first introduced in JDK 21. The ListIterator approach requires allocating a list, but yields a much simpler code in return. 
--- .../client6/v1/api/collections/batch/Batch.java | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java index 4a9b156d5..60865ab51 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java @@ -5,9 +5,8 @@ import java.time.Instant; import java.util.Collection; import java.util.Comparator; -import java.util.Iterator; import java.util.LinkedHashMap; -import java.util.Map; +import java.util.ListIterator; import java.util.OptionalInt; import java.util.Set; import java.util.TreeSet; @@ -176,15 +175,9 @@ synchronized void setMaxSize(int maxSizeNew) { } // Buffer exceeds the new limit. Move extra items to the backlog (LIFO). - Iterator> extra = buffer.reversed() - .entrySet().stream() - .limit(buffer.size() - maxSize) - .iterator(); - - while (extra.hasNext()) { - Data data = extra.next().getValue(); - addBacklog(data); - extra.remove(); + ListIterator extra = buffer.keySet().stream().toList().listIterator(); + while (extra.hasPrevious() && buffer.size() > maxSize) { + addBacklog(buffer.remove(extra.previous())); } } finally { checkInvariants(); From 270f72617d60f08529b7ffa7767b5cdcf1b07d1a Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 18 Feb 2026 22:48:36 +0100 Subject: [PATCH 18/22] fix(batch): add OPENED state for when the stream hasn't been started --- .../client6/v1/api/collections/batch/BatchContext.java | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java index 6d8e1cdf5..20da02411 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java +++ 
b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java @@ -173,7 +173,7 @@ void setClosing(Exception ex) { this.queue = new ArrayBlockingQueue<>(DEFAULT_QUEUE_SIZE); this.batch = new Batch(DEFAULT_BATCH_SIZE, maxSizeBytes); - setState(CLOSED); + setState(OPENED); } @@ -835,6 +835,13 @@ List shutdownNow() { } } + final State OPENED = new BaseState("OPENED") { + @Override + public void onEnter(State prev) { + closed = false; + } + }; + final State CLOSED = new BaseState("CLOSED") { @Override public void onEnter(State prev) { From 9f28126a461f9955a17bf7329cac154e8a40f479 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 18 Feb 2026 22:49:20 +0100 Subject: [PATCH 19/22] fix(batch): replace of -> ofNullable --- .../client6/v1/api/collections/CollectionHandleDefaults.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaults.java b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaults.java index d9026805a..47ee0dcba 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaults.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaults.java @@ -36,7 +36,7 @@ public static Function> none() } public CollectionHandleDefaults(Builder builder) { - this(Optional.of(builder.consistencyLevel), Optional.of(builder.tenant)); + this(Optional.ofNullable(builder.consistencyLevel), Optional.ofNullable(builder.tenant)); } public static final class Builder implements ObjectBuilder { From 101936bc03de669ad8d74ba2455471ddbd83d748 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 18 Feb 2026 23:20:28 +0100 Subject: [PATCH 20/22] fix(batch): create empty EnumSet via noneOf --- .../client6/v1/api/collections/batch/BatchContext.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java 
b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java index 20da02411..5318e5b11 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java @@ -531,7 +531,9 @@ enum Action { */ protected BaseState(String name, Action... permitted) { this.name = name; - this.permitted = EnumSet.copyOf(Arrays.asList(requireNonNull(permitted, "actions is null"))); + this.permitted = requireNonNull(permitted, "actions is null").length == 0 + ? EnumSet.noneOf(Action.class) + : EnumSet.copyOf(Arrays.asList(permitted)); } @Override From 7c1e2bc0e84d3d1f93b1261d09f6b81454584973 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 18 Feb 2026 23:42:28 +0100 Subject: [PATCH 21/22] fix(batch): start the context before returning it --- .../client6/v1/api/collections/batch/WeaviateBatchClient.java | 1 + 1 file changed, 1 insertion(+) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/WeaviateBatchClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/WeaviateBatchClient.java index 32953f74c..453862f5e 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/WeaviateBatchClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/WeaviateBatchClient.java @@ -47,6 +47,7 @@ public BatchContext start() { context.scheduleReconnect(GCP_RECONNECT_INTERVAL_SECONDS); } + context.start(); return context; } From d1d497d0279e057b1866f4ab08c558bc81b691ea Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Thu, 19 Feb 2026 20:22:01 +0100 Subject: [PATCH 22/22] fix(batch): handle happy path BatchContext can deal with happy path, i.e. no oom, no shutdowns, etc. 
--- .../v1/api/collections/batch/Batch.java | 1 + .../api/collections/batch/BatchContext.java | 246 +++++++++--------- .../v1/api/collections/batch/TaskHandle.java | 1 + .../batch/TranslatingStreamFactory.java | 9 +- 4 files changed, 138 insertions(+), 119 deletions(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java index 60865ab51..f3f7a1282 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java @@ -267,6 +267,7 @@ synchronized Collection clear() { Set removed = Set.copyOf(buffer.keySet()); buffer.clear(); + sizeBytes = 0; if (pendingMaxSize.isPresent()) { setMaxSize(pendingMaxSize.getAsInt()); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java index 5318e5b11..d52745976 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java @@ -8,7 +8,6 @@ import java.util.Collection; import java.util.EnumSet; import java.util.List; -import java.util.Optional; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; import java.util.concurrent.CancellationException; @@ -53,7 +52,7 @@ * @param the shape of properties for inserted objects. */ public final class BatchContext implements Closeable { - private final int DEFAULT_BATCH_SIZE = 1000; + private final int DEFAULT_BATCH_SIZE = 1_000; private final int DEFAULT_QUEUE_SIZE = 100; private final int MAX_RECONNECT_RETRIES = 5; @@ -81,6 +80,9 @@ public final class BatchContext implements Closeable { */ private final ScheduledExecutorService scheduledExec = Executors.newScheduledThreadPool(1); + /** The thread that created the context. 
*/ + private final Thread parent = Thread.currentThread(); + /** Stream factory creates new streams. */ private final StreamFactory streamFactory; @@ -141,26 +143,32 @@ public final class BatchContext implements Closeable { */ private volatile CountDownLatch workers; + /** closing completes the stream. */ + private final CompletableFuture closing = new CompletableFuture<>(); + + /** Executor for performing the shutdown sequence. */ + private final ExecutorService shutdownExec = Executors.newSingleThreadExecutor(); + /** Lightway check to ensure users cannot send on a closed context. */ private volatile boolean closed; - /** Closing state. */ - private volatile Closing closing; - - /** - * setClosing trasitions BatchContext to {@link Closing} state exactly once. - * Once this method returns, the caller can call {@code closing.await()}. - */ - void setClosing(Exception ex) { - if (closing == null) { - synchronized (Closing.class) { - if (closing == null) { - closing = new Closing(ex); - setState(closing); - } - } - } - } + // /** Closing state. */ + // private volatile Closing closing; + + // /** + // * setClosing trasitions BatchContext to {@link Closing} state exactly once. + // * Once this method returns, the caller can call {@code closing.await()}. + // */ + // void setClosing(Exception ex) { + // if (closing == null) { + // synchronized (Closing.class) { + // if (closing == null) { + // closing = new Closing(ex); + // setState(closing); + // } + // } + // } + // } BatchContext( StreamFactory streamFactory, @@ -173,8 +181,6 @@ void setClosing(Exception ex) { this.queue = new ArrayBlockingQueue<>(DEFAULT_QUEUE_SIZE); this.batch = new Batch(DEFAULT_BATCH_SIZE, maxSizeBytes); - setState(OPENED); - } /** Add {@link WeaviateObject} to the batch. 
*/ @@ -194,23 +200,32 @@ public TaskHandle add(BatchReference reference) throws InterruptedException { } void start() { + start(AWAIT_STARTED); + } + + void start(State nextState) { workers = new CountDownLatch(2); messages = streamFactory.createStream(new Recv()); - send = sendExec.submit(new Send()); + // Start the stream and await Started message. messages.onNext(Message.start(collectionHandleDefaults.consistencyLevel())); - setState(AWAIT_STARTED); + setState(nextState); + + // "send" routine must start after the nextState has been set. + send = sendExec.submit(new Send()); } /** * Reconnect waits for "send" and "recv" streams to exit * and restarts the process with a new stream. + * + * @param reconnecting Reconnecting instance that called reconnect. */ - void reconnect() throws InterruptedException, ExecutionException { + void reconnect(Reconnecting reconnecting) throws InterruptedException, ExecutionException { workers.await(); send.get(); - start(); + start(reconnecting); } /** @@ -235,21 +250,72 @@ public TaskHandle retry(TaskHandle taskHandle) throws InterruptedException { */ @Override public void close() throws IOException { - setClosing(null); - assert closing != null : "closing state not set"; + closed = true; try { - closing.await(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } catch (ExecutionException e) { + shutdown(); + } catch (InterruptedException | ExecutionException e) { + if (e instanceof InterruptedException || + e.getCause() instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } throw new IOException(e.getCause()); } finally { shutdownExecutors(); - setState(CLOSED); } } + private void shutdown() throws InterruptedException, ExecutionException { + CompletableFuture gracefulShutdown = CompletableFuture.runAsync(() -> { + try { + // Poison the queue -- this will signal "send" to drain the remaing + // items in the batch and in the backlog and exit. 
+ // + // If shutdownNow has been called previously and the "send" routine + // has been interrupted, this may block indefinitely. + // However, shutdownNow ensures that `closing` future is resolved. + queue.put(TaskHandle.POISON); + + // Wait for the send to exit before closing our end of the stream. + send.get(); + messages.onNext(Message.stop()); + messages.onCompleted(); + + // Wait for both "send" and "recv" to exit. + workers.await(); + closing.complete(null); + } catch (Exception e) { + closing.completeExceptionally(e); + } + + }, shutdownExec); + + // Complete shutdown as soon as one of these futures are completed. + // - gracefulShutdown completes if we managed to shutdown normally. + // - closing may complete sooner if shutdownNow is called. + CompletableFuture.anyOf(closing, gracefulShutdown).get(); + } + + private void shutdownNow(Exception ex) { + // Terminate the "send" routine and wait for it to exit. + // Since we're already in the error state we do not care + // much if it throws or not. + send.cancel(true); + try { + send.get(); + } catch (Exception e) { + } + + // Now report this error to the server and close the stream. + closing.completeExceptionally(ex); + messages.onError(Status.INTERNAL.withCause(ex).asRuntimeException()); + + // Since shutdownNow is never triggerred by the "main" thread, + // it may be blocked on trying to add to the queue. While batch + // context is active, we own this thread and may interrupt it. + parent.interrupt(); + } + private void shutdownExecutors() { BiConsumer> assertEmpty = (name, pending) -> { assert pending.isEmpty() : "'%s' service had %d tasks awaiting execution" @@ -264,8 +330,8 @@ private void shutdownExecutors() { pending = scheduledExec.shutdownNow(); assertEmpty.accept("oom", pending); - pending = closing.shutdownNow(); - assertEmpty.accept("close", pending); + pending = shutdownExec.shutdownNow(); + assertEmpty.accept("shutdown", pending); } /** Set the new state and notify awaiting threads. 
*/ @@ -273,6 +339,7 @@ void setState(State nextState) { requireNonNull(nextState, "nextState is null"); lock.lock(); + System.out.println("setState " + state + " => " + nextState); try { State prev = state; state = nextState; @@ -296,6 +363,7 @@ void setState(State nextState) { private void onEvent(Event event) { lock.lock(); try { + System.out.println("onEvent " + event); state.onEvent(event); } finally { lock.unlock(); @@ -337,12 +405,14 @@ private void trySend() { while (!Thread.currentThread().isInterrupted()) { if (batch.isFull()) { + System.out.println("==[send batch]==>"); send(); } TaskHandle task = queue.take(); if (task == TaskHandle.POISON) { + System.out.println("took POISON"); drain(); return; } @@ -354,17 +424,11 @@ private void trySend() { assert existing == null : "duplicate tasks in progress, id=" + existing.id(); } } catch (InterruptedException ignored) { - // This thread is only interrupted in the RECONNECTING state, not by - // the user's code. Allow this method to exit normally to close our - // end of the stream. Thread.currentThread().interrupt(); } catch (Exception e) { onEvent(new Event.ClientError(e)); return; } - - messages.onNext(Message.stop()); - messages.onCompleted(); } /** @@ -489,9 +553,9 @@ public void onEvent(Event event) { requireNonNull(event, "event is null"); if (event instanceof Event.Acks acks) { - Collection remaining = batch.clear(); - if (!remaining.isEmpty()) { - throw ProtocolViolationException.incompleteAcks(List.copyOf(remaining)); + Collection removed = batch.clear(); + if (!acks.acked().containsAll(removed)) { + throw ProtocolViolationException.incompleteAcks(List.copyOf(removed)); } acks.acked().forEach(id -> { TaskHandle task = wip.get(id); @@ -559,11 +623,11 @@ public boolean canPrepareNext() { *

    • {@link Event#SHUTTING_DOWN} -- transition into * {@link ServerShuttingDown}. *
    • {@link Event.StreamHangup -- transition into {@link Reconnecting} state. - *
    • {@link Event.ClientError -- transition into {@link Closing} state with - * exception. + *
    • {@link Event.ClientError -- shutdown the service immediately. *
    * * @throws ProtocolViolationException If event cannot be handled in this state. + * @see BatchContext#shutdownNow */ @Override public void onEvent(Event event) { @@ -590,6 +654,9 @@ private final void onResults(Event.Results results) { } private final void onBackoff(Event.Backoff backoff) { + System.out.print("========== BACKOFF =============="); + System.out.print(backoff.maxSize()); + System.out.print("================================="); batch.setMaxSize(backoff.maxSize()); } @@ -599,13 +666,15 @@ private final void onShuttingDown() { private final void onStreamClosed(Event event) { if (event instanceof Event.StreamHangup hangup) { - // TODO(dyma): log error? + hangup.exception().printStackTrace(); + } + if (!send.isDone()) { + setState(new Reconnecting(MAX_RECONNECT_RETRIES)); } - setState(new Reconnecting(MAX_RECONNECT_RETRIES)); } private final void onClientError(Event.ClientError error) { - setClosing(error.exception()); + shutdownNow(error.exception()); } @Override @@ -662,8 +731,8 @@ public void onEvent(Event event) { // so, while onEvent allows InterruptedException to stay responsive, // in practice this thread will only be interrupted by the thread pool, // which already knows it's being shut down. - } catch (ExecutionException e) { - throw new RuntimeException(e); + } catch (ExecutionException ex) { + onEvent(new Event.ClientError(ex)); } } super.onEvent(event); @@ -718,6 +787,13 @@ private Reconnecting(int maxRetries) { @Override public void onEnter(State prev) { + // The reconnected state is re-set every time the stream restarts. + // This ensures that onEnter hook is only called the first + // time we enter Reconnecting state. 
+ if (prev == this) { + return; + } + send.cancel(true); if (prev.getClass() != ServerShuttingDown.class) { @@ -747,7 +823,7 @@ public void onEvent(Event event) { setState(ACTIVE); } else if (event instanceof Event.StreamHangup) { if (retries == maxRetries) { - onEvent(new ClientError(new IOException("Server unavailable"))); + onEvent(new Event.ClientError(new IOException("Server unavailable"))); } else { reconnectAfter(1 * 2 ^ retries); } @@ -775,7 +851,7 @@ private void reconnectAfter(long delaySeconds) { scheduledExec.schedule(() -> { try { - reconnect(); + reconnect(this); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } catch (ExecutionException e) { @@ -785,72 +861,6 @@ private void reconnectAfter(long delaySeconds) { } } - private final class Closing extends BaseState { - /** Service executor for polling {@link #workers} status before closing. */ - private final ExecutorService exec = Executors.newSingleThreadExecutor(); - - /** closed completes the stream. 
*/ - private final CompletableFuture future = new CompletableFuture<>(); - - private final Optional ex; - - private Closing(Exception ex) { - super("CLOSING"); - this.ex = Optional.ofNullable(ex); - } - - @Override - public void onEnter(State prev) { - exec.execute(() -> { - try { - stopSend(); - workers.await(); - future.complete(null); - } catch (Exception e) { - future.completeExceptionally(e); - } - }); - } - - @Override - public void onEvent(Event event) { - if (event != Event.EOF) { - super.onEvent(event); // falthrough - } - } - - private void stopSend() throws InterruptedException { - if (ex.isEmpty()) { - queue.put(TaskHandle.POISON); - } else { - messages.onError(Status.INTERNAL.withCause(ex.get()).asRuntimeException()); - send.cancel(true); - } - } - - void await() throws InterruptedException, ExecutionException { - future.get(); - } - - List shutdownNow() { - return exec.shutdownNow(); - } - } - - final State OPENED = new BaseState("OPENED") { - @Override - public void onEnter(State prev) { - closed = false; - } - }; - - final State CLOSED = new BaseState("CLOSED") { - @Override - public void onEnter(State prev) { - closed = true; - } - }; - // -------------------------------------------------------------------------- private final ScheduledExecutorService reconnectExec = Executors.newScheduledThreadPool(1); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java index ab8ad1149..aad9c8a71 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java @@ -134,6 +134,7 @@ void setError(String error) { */ private void setResult(Result result) { if (!acked.isDone()) { + // TODO(dyma): can this happen due to us? 
throw new IllegalStateException("Result can only be set for an ack'ed task"); } this.result.complete(result); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java index 5b60600c0..2e6344ada 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java @@ -93,19 +93,25 @@ public void onNext(BatchStreamReply reply) { switch (reply.getMessageCase()) { case STARTED: event = Event.STARTED; + break; case SHUTTING_DOWN: event = Event.SHUTTING_DOWN; + break; case SHUTDOWN: event = Event.EOF; + break; case OUT_OF_MEMORY: // TODO(dyma): read this value from the message event = new Event.Oom(300); + break; case BACKOFF: event = new Event.Backoff(reply.getBackoff().getBatchSize()); + break; case ACKS: Stream uuids = reply.getAcks().getUuidsList().stream(); Stream beacons = reply.getAcks().getBeaconsList().stream(); event = new Event.Acks(Stream.concat(uuids, beacons).toList()); + break; case RESULTS: List successful = reply.getResults().getSuccessesList().stream() .map(detail -> { @@ -130,8 +136,9 @@ public void onNext(BatchStreamReply reply) { }) .collect(Collectors.toUnmodifiableMap(Map.Entry::getKey, Map.Entry::getValue)); event = new Event.Results(successful, errors); + break; case MESSAGE_NOT_SET: - throw new IllegalArgumentException("Message not set"); + throw new ProtocolViolationException("Message not set"); } delegate.onNext(event);