背景ray 简介对 trace 的需求探索历程skywalking-pythonray + skywalking-python执行 demoray remote() 函数解析sw-ray-plugin切点插件代码插件测试打通 trace总结
Ray 提供的能力:
分布式计算框架:能够几乎无感地将原本在单机上运行的程序(尤其是计算密集或数据密集型的任务)轻松扩展到集群中,充分利用集群或核心资源进行并行计算
类比 mar、spark、flink 等分布式框架,本质上都是将程序从单机拓展到分布式
专注于 ai:用于数据处理、模型训练、超参数调优和模型服务等常见机器学习场景的工具
统一资源管理:通过自动资源处理从笔记本电脑无缝扩展到云端
Ray 解决的问题:
例如以下问题:
for 循环要跑 10 万次,每次计算都很耗时,但彼此独立,每次执行之间互不干扰
训练机器学习模型时,参数调优需要同时跑几十个 job。
模型推理服务,请求量巨大,一台机器根本扛不住。
传统的做法是写多线程/多进程,但管理起来非常复杂,而且无法扩展到多台机器。
Ray 就是为了解决这些问题而生的。它提供了一个简单直观的 API,是能够用最少的代码修改,就能获得分布式计算的能力。
核心优势:
简单易用: 几个装饰器就能实现并行和分布式。
高性能: 底层用 C++ 优化,任务调度非常高效。
生态丰富: 提供了用于机器学习全流程的一系列库。
可扩展性: 从笔记本电脑到大型集群无缝扩展。
主要是用于 debug 和 monitor,以及打通现有服务之间的 trace。
目前 Ray 支持了 OpenTelemetry 协议的 trace,这很好,但是考虑到公司目前 trace 体系基于 skywalking,需要考虑兼容的问题。
初步构想的一些方案:
引入 skywalkig-oap:skywalking 对 OpenTelemetry 的有限兼容,使其能够通过 OAP 接受 OpenTelemetry 协议的指标,但在我们的场景中仅使用到了 skywalking-sniffer 的能力,由此在不改变使用 skywalking 的同时,可以通过引入 OAP 来支持,将目前 trace 部分迁移到 OAP 中。
从 skywalking 迁移至 OpenTelemetry:不得不说,OpenTelemetry 是未来可观测领域的事实标准,但将公司内部的 trace 体系升级为基于 OpenTelemetry 协议,这个改造成本较高,而且由于目前对 skywalking 的依赖(对 sniffer 的 n 多魔改),这几乎是所有方案中成本最高的
ray 服务接入 skywalking-python:不使用 ray 原生的 trace 能力,通过接入 skywalking-python 支持 trace。但有以下问题需解决:
开源 skywalking-python 不支持 ray,无对应插件,需自行实现
目前公司 skywalking 版本为 v6.6.0 对应协议版本 sw6,这是 2019 年的版本,已经相对落后,skywalking-python 仅支持 sw8 版本。考虑到 skywalking 的升级成本(大量魔改),对接 trace 时需进行协议转换。
后续主要目标是:
能够采集在 ray 上部署的 python 服务或大模型产生的 trace,并融合到公司目前的 trace 体系中去
具体需要做的事:
skywalking-python:调研 skywalking-python,了解内部实现原理
ray + swkwalking-python:初步尝试编写 ray + swkwalking-python demo
sw-ray-plugin:新增 ray plugin 采集 ray 内部服务间调用 trace
打通 trace:对 skywalking-python trace 部分改造(skywalking v8),打通目前公司 trace 模型(skywalking v6)
项目地址:skywalking-python
直接 clone 到本地启动 demo(commit id:fb3fb005650e2489164978b7804117c7ade1529a
)
找到 /demo/docker-compose.yaml
经测试,需将其中的 kafka.image
修改为 confluentinc/cp-kafka:7.2.15
否则启动会报错
启动所有 service:
启动 /demo/flask_provider_single.py
发起请求:
1curl http://127.0.0.1:9999
打开 oap-ui,查看 trace:
创建 demo 程序并启动:
x1import time
2
3import pydevd_pycharm
4import skywalking.trace.context
5import starlette
6from fastapi import FastAPI
7
8from ray import serve
9from ray.serve.handle import DeploymentHandle
10
11from skywalking import agent, config
12
13app = FastAPI()
14
15deployment .
16class Downstream:
17
18 def __init__(self):
19 config.init(agent_collector_backend_services='localhost:11800', agent_protocol='grpc',
20 agent_name='ray-downstream',
21 kafka_bootstrap_servers='localhost:9094', # If you use kafka, set this
22 agent_instance_name='instance-downstream',
23 agent_experimental_fork_support=True,
24 agent_logging_level='DEBUG',
25 agent_log_reporter_active=True,
26 agent_meter_reporter_active=True,
27 agent_profile_active=True)
28 agent.start()
29
30 async def hello(self) -> str:
31 ctx = skywalking.trace.context.get_context()
32 print(f"downstream traceId: {ctx.segment.related_traces[0]}")
33 return "im' downstream"
34
35
36deployment .
37ingress(app) .
38class Upstream:
39
40 def __init__(self, downstream: DeploymentHandle):
41 config.init(agent_collector_backend_services='localhost:11800', agent_protocol='grpc',
42 agent_name='ray-upstream',
43 kafka_bootstrap_servers='localhost:9094', # If you use kafka, set this
44 agent_instance_name='instance-upstream',
45 agent_experimental_fork_support=True,
46 agent_logging_level='DEBUG',
47 agent_log_reporter_active=True,
48 agent_meter_reporter_active=True,
49 agent_profile_active=True)
50 agent.start()
51 self.downstream = downstream
52
53 post("/upstream") .
54 async def upstream(self, request: starlette.requests.Request) -> str:
55 # pydevd_pycharm.settrace('localhost', port=5331, stdoutToServer=True, stderrToServer=True)
56 resp = await self.downstream.hello.remote()
57 ctx = skywalking.trace.context.get_context()
58 print(f"upstream traceId: {ctx.segment.related_traces[0]}")
59 return f"downstream resp: {resp}"
60
61
62ser = serve.run(Upstream.bind(Downstream.bind()), route_prefix="/hello-ray")
63
64while True:
65 time.sleep(1)
66
执行请求:
xxxxxxxxxx
11curl -X POST http://127.0.0.1:8000/hello-ray/upstream
可以看到,只有一个 upstream 程序的 span 信息,并未看到 downstream 相关的数据
这可能是由于 self.downstream.hello.remote()
指令内部未被 skywalking agent 拦截或是其他原因导致
理论上 ray 内部服务是通过 gRPC 调用的,而且 skwaylking-python 也对 gRPC 进行了增强,那么为什么会采集不到 downstream 的 trace 数据呢?
需进一步排查和定位问题
此函数中的 kwargs
代表关键字参数,将通过底层 rpc 框架传递至下游函数
非常适合用来传递 skywalking 的 trace 信息
client 切点:ray.serve.handle.DeploymentHandle.remote
server 切点:ray.serve._private.replica.UserCallableWrapper.call_user_method
xxxxxxxxxx
1511#
2# Licensed to the Apache Software Foundation (ASF) under one or more
3# contributor license agreements. See the NOTICE file distributed with
4# this work for additional information regarding copyright ownership.
5# The ASF licenses this file to You under the Apache License, Version 2.0
6# (the "License"); you may not use this file except in compliance with
7# the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18from typing import Tuple, Dict, Any
19
20from ray.serve._private.common import RequestMetadata
21
22from skywalking import Layer, Component, config
23from skywalking.trace.carrier import Carrier
24from skywalking.trace.context import get_context, NoopContext
25from skywalking.trace.span import NoopSpan
26from skywalking.trace.tags import Tag
27
28link_vector = ['https://docs.ray.io/en/latest/serve/']
29support_matrix = {
30 'ray': {
31 '>=2.0.0': []
32 }
33}
34note = """
35This plugin enables distributed tracing for Ray Serve applications.
36It uses argument injection to propagate tracing context across remote calls
37since Ray Serve doesn't have built-in metadata transport mechanisms.
38"""
39
40
41# Custom tags for Ray Serve
42class TagRayDeployment(Tag):
43 key = 'ray.deployment'
44
45
46class TagRayMethod(Tag):
47 key = 'ray.method'
48
49
50class TagRayNodeId(Tag):
51 key = 'ray.node_id'
52
53
54# Special context key - highly unlikely to conflict with user arguments
55_SW_CONTEXT_KEY = '__skywalking_tracing_context_ray__'
56
57
58def install_client():
59 try:
60 from ray.serve.handle import DeploymentHandle
61 import ray
62 except ImportError:
63 return
64
65 _original_remote = DeploymentHandle.remote
66
67 def _sw_remote(self, *args, **kwargs):
68 deployment_name = getattr(self, 'deployment_name', 'unknown')
69 method_name = getattr(self.handle_options, 'method_name', 'unknown')
70 peer = f"ray-server://{deployment_name}.{method_name}"
71
72 span = get_context().new_exit_span(
73 op=f"ray-remote:{deployment_name}.{method_name}",
74 peer=peer,
75 component=Component.General
76 )
77
78 with span:
79 span.layer = Layer.RPCFramework
80 span.tag(TagRayDeployment(deployment_name))
81 span.tag(TagRayMethod(method_name))
82
83 # Get current Ray node info if available
84 try:
85 node_id = ray.get_runtime_context().get_node_id()
86 if node_id:
87 span.tag(TagRayNodeId(node_id))
88 except Exception:
89 pass
90
91 # Inject tracing context as a special keyword argument
92 carrier = span.inject()
93 sw_context = {}
94 for item in carrier:
95 sw_context[item.key.capitalize()] = item.val
96
97 kwargs[_SW_CONTEXT_KEY] = sw_context
98
99 try:
100 result = _original_remote(self, *args, **kwargs)
101 return result
102 except Exception as e:
103 span.raised()
104 raise
105
106 DeploymentHandle.remote = _sw_remote
107
108
109def install_server():
110 try:
111 from ray.serve._private.replica import UserCallableWrapper
112 import ray
113 except ImportError:
114 return
115
116 _origin_call_user_method = UserCallableWrapper.call_user_method
117
118 async def _sw_call_user_method(self,
119 request_metadata: RequestMetadata,
120 request_args: Tuple[Any],
121 request_kwargs: Dict[str, Any]) -> Any:
122 if request_kwargs[_SW_CONTEXT_KEY] is None:
123 return await _origin_call_user_method(self, request_metadata, request_args, request_kwargs)
124
125 deployment_name = self._deployment_id.name
126 method_name = request_metadata.call_method
127 sw_context = dict.pop(request_kwargs, _SW_CONTEXT_KEY)
128 carrier = Carrier()
129 for item in carrier:
130 item.val = sw_context[item.key.capitalize()]
131 endpoint = f"{deployment_name}.{method_name}"
132 span = NoopSpan(NoopContext()) if config.ignore_http_method_check(endpoint) \
133 else get_context().new_entry_span(op=f"ray-function:{endpoint}", carrier=carrier)
134 with span:
135 span.layer = Layer.RPCFramework
136 span.component = Component.General
137 span.tag(TagRayDeployment(deployment_name))
138 span.tag(TagRayMethod(method_name))
139 try:
140 return await _origin_call_user_method(self, request_metadata, request_args, request_kwargs)
141 except Exception:
142 span.raised()
143 raise
144
145 UserCallableWrapper.call_user_method = _sw_call_user_method
146
147
148def install():
149 install_client()
150 install_server()
151
xxxxxxxxxx
11curl -X POST http://127.0.0.1:8000/hello-ray/upstream
TODO
TODO