Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@
import org.apache.rocketmq.proxy.config.ProxyConfig;
import org.apache.rocketmq.proxy.grpc.v2.AbstractMessingActivity;
import org.apache.rocketmq.proxy.grpc.v2.channel.GrpcChannelManager;
import org.apache.rocketmq.proxy.grpc.v2.channel.GrpcClientChannel;
import org.apache.rocketmq.proxy.grpc.v2.common.GrpcClientSettingsManager;
import org.apache.rocketmq.proxy.grpc.v2.common.GrpcConverter;
import org.apache.rocketmq.proxy.grpc.v2.common.GrpcProxyException;
import org.apache.rocketmq.proxy.processor.MessagingProcessor;
import org.apache.rocketmq.proxy.processor.QueueSelector;
import org.apache.rocketmq.proxy.service.route.AddressableMessageQueue;
Expand Down Expand Up @@ -135,14 +137,22 @@ public void receiveMessage(ProxyContext ctx, ReceiveMessageRequest request,
).thenAccept(popResult -> {
if (proxyConfig.isEnableProxyAutoRenew() && request.getAutoRenew()) {
if (PopStatus.FOUND.equals(popResult.getPopStatus())) {
GrpcClientChannel clientChannel = grpcChannelManager.getChannel(ctx.getClientID());
if (clientChannel == null) {
GrpcProxyException e = new GrpcProxyException(Code.MESSAGE_NOT_FOUND,
String.format("The client [%s] is disconnected.", ctx.getClientID()));
popResult.getMsgFoundList().forEach(messageExt ->
writer.processThrowableWhenWriteMessage(e, ctx, request, messageExt));
throw e;
}
List<MessageExt> messageExtList = popResult.getMsgFoundList();
for (MessageExt messageExt : messageExtList) {
String receiptHandle = messageExt.getProperty(MessageConst.PROPERTY_POP_CK);
if (receiptHandle != null) {
MessageReceiptHandle messageReceiptHandle =
new MessageReceiptHandle(group, topic, messageExt.getQueueId(), receiptHandle, messageExt.getMsgId(),
messageExt.getQueueOffset(), messageExt.getReconsumeTimes());
messagingProcessor.addReceiptHandle(ctx, grpcChannelManager.getChannel(ctx.getClientID()), group, messageExt.getMsgId(), messageReceiptHandle);
messagingProcessor.addReceiptHandle(ctx, clientChannel, group, messageExt.getMsgId(), messageReceiptHandle);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,15 @@ public void writeAndComplete(ProxyContext ctx, ReceiveMessageRequest request, Po
.setStatus(ResponseBuilder.getInstance().buildStatus(Code.MESSAGE_NOT_FOUND, "no match message"))
.build());
} else {
streamObserver.onNext(ReceiveMessageResponse.newBuilder()
.setStatus(ResponseBuilder.getInstance().buildStatus(Code.OK, Code.OK.name()))
.build());
try {
streamObserver.onNext(ReceiveMessageResponse.newBuilder()
.setStatus(ResponseBuilder.getInstance().buildStatus(Code.OK, Code.OK.name()))
.build());
} catch (Throwable t) {
messageFoundList.forEach(messageExt ->
this.processThrowableWhenWriteMessage(t, ctx, request, messageExt));
throw t;
}
Iterator<MessageExt> messageIterator = messageFoundList.iterator();
while (messageIterator.hasNext()) {
MessageExt curMessageExt = messageIterator.next();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,21 @@
import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.StreamObserver;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import org.apache.rocketmq.client.consumer.AckResult;
import org.apache.rocketmq.client.consumer.PopResult;
import org.apache.rocketmq.client.consumer.PopStatus;
import org.apache.rocketmq.common.MixAll;
import org.apache.rocketmq.common.constant.PermName;
import org.apache.rocketmq.common.consumer.ReceiptHandle;
import org.apache.rocketmq.common.message.MessageAccessor;
import org.apache.rocketmq.common.message.MessageConst;
import org.apache.rocketmq.common.message.MessageExt;
import org.apache.rocketmq.proxy.common.ProxyContext;
import org.apache.rocketmq.proxy.config.ConfigurationManager;
import org.apache.rocketmq.proxy.grpc.v2.BaseActivityTest;
Expand All @@ -61,6 +68,8 @@
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

public class ReceiveMessageActivityTest extends BaseActivityTest {

Expand Down Expand Up @@ -223,6 +232,87 @@ public void testReceiveMessageIllegalInvisibleTimeTooLarge() {
assertEquals(Code.ILLEGAL_INVISIBLE_TIME, getResponseCodeFromReceiveMessageResponseList(responseArgumentCaptor.getAllValues()));
}

@Test
public void testReceiveMessageAddReceiptHandle() {
ConfigurationManager.getProxyConfig().setEnableProxyAutoRenew(true);
StreamObserver<ReceiveMessageResponse> receiveStreamObserver = mock(ServerCallStreamObserver.class);
doNothing().when(receiveStreamObserver).onNext(any());
when(this.grpcClientSettingsManager.getClientSettings(any())).thenReturn(Settings.newBuilder().getDefaultInstanceForType());

MessageExt messageExt1 = new MessageExt();
String msgId1 = "msgId1";
String popCk1 = "0 0 60000 0 0 broker 0 0 0";
messageExt1.setTopic(TOPIC);
messageExt1.setMsgId(msgId1);
MessageAccessor.putProperty(messageExt1, MessageConst.PROPERTY_POP_CK, popCk1);
messageExt1.setBody("body1".getBytes());
MessageExt messageExt2 = new MessageExt();
String msgId2 = "msgId2";
String popCk2 = "0 0 60000 0 0 broker 0 1 1000";
messageExt2.setTopic(TOPIC);
messageExt2.setMsgId(msgId2);
MessageAccessor.putProperty(messageExt2, MessageConst.PROPERTY_POP_CK, popCk2);
messageExt2.setBody("body2".getBytes());
PopResult popResult = new PopResult(PopStatus.FOUND, Arrays.asList(messageExt1, messageExt2));
when(this.messagingProcessor.popMessage(
any(),
any(),
anyString(),
anyString(),
anyInt(),
anyLong(),
anyLong(),
anyInt(),
any(),
anyBoolean(),
any(),
isNull(),
anyLong())).thenReturn(CompletableFuture.completedFuture(popResult));
ArgumentCaptor<String> msgIdCaptor = ArgumentCaptor.forClass(String.class);
ArgumentCaptor<ReceiptHandle> receiptHandleCaptor = ArgumentCaptor.forClass(ReceiptHandle.class);
when(this.messagingProcessor.changeInvisibleTime(
any(),
receiptHandleCaptor.capture(),
msgIdCaptor.capture(),
anyString(),
anyString(),
anyLong())).thenReturn(CompletableFuture.completedFuture(new AckResult()));

// normal
ProxyContext ctx = createContext();
this.grpcChannelManager.createChannel(ctx, ctx.getClientID());
ReceiveMessageRequest receiveMessageRequest = ReceiveMessageRequest.newBuilder()
.setGroup(Resource.newBuilder().setName(CONSUMER_GROUP).build())
.setMessageQueue(MessageQueue.newBuilder().setTopic(Resource.newBuilder().setName(TOPIC).build()).build())
.setAutoRenew(true)
.setFilterExpression(FilterExpression.newBuilder()
.setType(FilterType.TAG)
.setExpression("*")
.build())
.build();
this.receiveMessageActivity.receiveMessage(ctx, receiveMessageRequest, receiveStreamObserver);
verify(this.messagingProcessor, times(0)).changeInvisibleTime(
any(),
any(),
anyString(),
anyString(),
anyString(),
anyLong());

// abnormal
this.grpcChannelManager.removeChannel(ctx.getClientID());
this.receiveMessageActivity.receiveMessage(ctx, receiveMessageRequest, receiveStreamObserver);
verify(this.messagingProcessor, times(2)).changeInvisibleTime(
any(),
any(),
anyString(),
anyString(),
anyString(),
anyLong());
assertEquals(Arrays.asList(msgId1, msgId2), msgIdCaptor.getAllValues());
assertEquals(Arrays.asList(popCk1, popCk2), receiptHandleCaptor.getAllValues().stream().map(ReceiptHandle::encode).collect(Collectors.toList()));
}

@Test
public void testReceiveMessage() {
StreamObserver<ReceiveMessageResponse> receiveStreamObserver = mock(ServerCallStreamObserver.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
Expand Down Expand Up @@ -90,16 +91,17 @@ public void testWriteMessage() {
messageExtList.add(createMessageExt(TOPIC, "tag"));
messageExtList.add(createMessageExt(TOPIC, "tag"));
PopResult popResult = new PopResult(PopStatus.FOUND, messageExtList);
ReceiveMessageRequest receiveMessageRequest = ReceiveMessageRequest.newBuilder()
.setGroup(Resource.newBuilder().setName(CONSUMER_GROUP).build())
.setMessageQueue(MessageQueue.newBuilder().setTopic(Resource.newBuilder().setName(TOPIC).build()).build())
.setFilterExpression(FilterExpression.newBuilder()
.setType(FilterType.TAG)
.setExpression("*")
.build())
.build();
writer.writeAndComplete(
ProxyContext.create(),
ReceiveMessageRequest.newBuilder()
.setGroup(Resource.newBuilder().setName(CONSUMER_GROUP).build())
.setMessageQueue(MessageQueue.newBuilder().setTopic(Resource.newBuilder().setName(TOPIC).build()).build())
.setFilterExpression(FilterExpression.newBuilder()
.setType(FilterType.TAG)
.setExpression("*")
.build())
.build(),
receiveMessageRequest,
popResult
);

Expand All @@ -114,6 +116,16 @@ public void testWriteMessage() {
assertEquals(messageExtList.get(0).getMsgId(), responseArgumentCaptor.getAllValues().get(1).getMessage().getSystemProperties().getMessageId());

assertEquals(messageExtList.get(1).getMsgId(), changeInvisibleTimeMsgIdCaptor.getValue());

// case: fail to write response status at first step
doThrow(new RuntimeException()).when(streamObserver).onNext(any());
writer.writeAndComplete(
ProxyContext.create(),
receiveMessageRequest,
popResult
);
verify(this.messagingProcessor, times(3))
.changeInvisibleTime(any(), any(), anyString(), anyString(), anyString(), anyLong());
}

@Test
Expand Down
Loading