Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 77 additions & 25 deletions connectrpc-otel/connectrpc_otel/_interceptor.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import time
from contextlib import AbstractContextManager, contextmanager
from typing import TYPE_CHECKING, TypeAlias, TypeVar, cast

from opentelemetry.metrics import MeterProvider, get_meter_provider
from opentelemetry.propagate import get_global_textmap
from opentelemetry.propagators.textmap import Setter, TextMapPropagator, default_setter
from opentelemetry.trace import (
Expand All @@ -12,15 +14,18 @@
get_current_span,
get_tracer_provider,
)
from opentelemetry.util.types import AttributeValue

from connectrpc.errors import ConnectError

from ._semconv import (
CLIENT_ADDRESS,
CLIENT_PORT,
ERROR_TYPE,
RPC_CLIENT_CALL_DURATION,
RPC_METHOD,
RPC_RESPONSE_STATUS_CODE,
RPC_SERVER_CALL_DURATION,
RPC_SYSTEM_NAME,
SERVER_ADDRESS,
SERVER_PORT,
Expand All @@ -31,14 +36,12 @@
if TYPE_CHECKING:
from collections.abc import Iterator, MutableMapping

from opentelemetry.util.types import AttributeValue

from connectrpc.request import RequestContext

REQ = TypeVar("REQ")
RES = TypeVar("RES")

Token: TypeAlias = tuple[AbstractContextManager, Span]
Token: TypeAlias = tuple[AbstractContextManager, Span, float, dict[str, AttributeValue]]

# Workaround bad typing
_DEFAULT_TEXTMAP_SETTER = cast("Setter[MutableMapping[str, str]]", default_setter)
Expand All @@ -52,6 +55,7 @@ def __init__(
*,
propagator: TextMapPropagator | None = None,
tracer_provider: TracerProvider | None = None,
meter_provider: MeterProvider | None = None,
client: bool = False,
) -> None:
"""Creates a new OpenTelemetry interceptor.
Expand All @@ -68,13 +72,51 @@ def __init__(
self._tracer = tracer_provider.get_tracer("connectrpc-otel", __version__)
self._propagator = propagator or get_global_textmap()

meter_provider = meter_provider or get_meter_provider()
meter = meter_provider.get_meter("connectrpc-otel", __version__)

self._call_duration = meter.create_histogram(
name=(RPC_CLIENT_CALL_DURATION if client else RPC_SERVER_CALL_DURATION),
description=f"Measures the duration of an {'outgoing' if client else 'incoming'} Remote Procedure Call (RPC)",
unit="s",
explicit_bucket_boundaries_advisory=[
0.005,
0.01,
0.025,
0.05,
0.075,
0.1,
0.25,
0.5,
0.75,
1,
2.5,
5,
7.5,
10,
],
)

async def on_start(self, ctx: RequestContext) -> Token:
return self.on_start_sync(ctx)

def on_start_sync(self, ctx: RequestContext) -> Token:
cm = self._start_span(ctx)
start_time = time.perf_counter()

rpc_method = f"{ctx.method().service_name}/{ctx.method().name}"
shared_attrs: dict[str, AttributeValue] = {
RPC_SYSTEM_NAME: RpcSystemNameValues.CONNECTRPC.value,
RPC_METHOD: rpc_method,
}

if sa := ctx.server_address():
addr, port = sa.rsplit(":", 1)
shared_attrs[SERVER_ADDRESS] = addr
shared_attrs[SERVER_PORT] = int(port)

cm = self._start_span(ctx, rpc_method, shared_attrs)
span = cm.__enter__()
return cm, span
return cm, span, start_time, shared_attrs

async def on_end(
self, token: Token, ctx: RequestContext, error: Exception | None
Expand All @@ -84,15 +126,28 @@ async def on_end(
def on_end_sync(
self, token: Token, ctx: RequestContext, error: Exception | None
) -> None:
cm, span = token
self._finish_span(span, error)
cm, span, start_time, shared_attrs = token
end_time = time.perf_counter()
error_attrs = self._get_error_attributes(error)
if error_attrs:
span.set_attributes(error_attrs)
# Won't use shared_attrs anymore, no need to copy.
metric_attrs = shared_attrs
if error_attrs:
metric_attrs.update(error_attrs)
self._call_duration.record(end_time - start_time, metric_attrs)
if error:
cm.__exit__(type(error), error, error.__traceback__)
else:
cm.__exit__(None, None, None)

@contextmanager
def _start_span(self, ctx: RequestContext) -> Iterator[Span]:
def _start_span(
self,
ctx: RequestContext,
span_name: str,
shared_attrs: dict[str, AttributeValue],
) -> Iterator[Span]:
parent_otel_ctx = None
if self._client:
span_kind = SpanKind.CLIENT
Expand All @@ -105,30 +160,27 @@ def _start_span(self, ctx: RequestContext) -> Iterator[Span]:
carrier = ctx.request_headers()
parent_otel_ctx = self._propagator.extract(carrier)

rpc_method = f"{ctx.method().service_name}/{ctx.method().name}"
attrs: dict[str, AttributeValue] = shared_attrs.copy()

attrs: MutableMapping[str, AttributeValue] = {
RPC_SYSTEM_NAME: RpcSystemNameValues.CONNECTRPC.value,
RPC_METHOD: rpc_method,
}
if sa := ctx.server_address():
addr, port = sa.rsplit(":", 1)
attrs[SERVER_ADDRESS] = addr
attrs[SERVER_PORT] = int(port)
if ca := ctx.client_address():
addr, port = ca.rsplit(":", 1)
attrs[CLIENT_ADDRESS] = addr
attrs[CLIENT_PORT] = int(port)

with self._tracer.start_as_current_span(
rpc_method, kind=span_kind, attributes=attrs, context=parent_otel_ctx
span_name, kind=span_kind, attributes=attrs, context=parent_otel_ctx
) as span:
yield span

def _finish_span(self, span: Span, error: Exception | None) -> None:
if error:
if isinstance(error, ConnectError):
span.set_attribute(RPC_RESPONSE_STATUS_CODE, error.code.value)
else:
span.set_attribute(RPC_RESPONSE_STATUS_CODE, "unknown")
span.set_attribute(ERROR_TYPE, type(error).__qualname__)
def _get_error_attributes(
self, error: Exception | None
) -> dict[str, AttributeValue] | None:
if not error:
return None

return {
ERROR_TYPE: type(error).__qualname__,
RPC_RESPONSE_STATUS_CODE: error.code.value
if isinstance(error, ConnectError)
else "unknown",
}
2 changes: 2 additions & 0 deletions connectrpc-otel/connectrpc_otel/_semconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
CLIENT_ADDRESS: Final = "client.address"
CLIENT_PORT: Final = "client.port"
ERROR_TYPE: Final = "error.type"
RPC_CLIENT_CALL_DURATION: Final = "rpc.client.call.duration"
RPC_SERVER_CALL_DURATION: Final = "rpc.server.call.duration"
RPC_METHOD: Final = "rpc.method"
RPC_RESPONSE_STATUS_CODE: Final = "rpc.response.status_code"
RPC_SYSTEM_NAME: Final = "rpc.system.name"
Expand Down
Loading