Source code for pyokx.websocket_handling


"""
This module handles real-time interaction with the OKX exchange WebSocket APIs. It subscribes to public, business, and
private channels, parses incoming messages into the pydantic models registered for each channel, and optionally appends
them to Redis streams for further analysis or decision-making. The main coroutine `okx_websockets_main_run` opens the
websocket connections, routes every message to `ws_callback`, and runs a periodic heartbeat loop (logging Redis stats)
until it is interrupted, after which it unsubscribes and closes the connections.
"""

import asyncio
import json
import os
from pprint import pprint

import dotenv

from pyokx.okx_market_maker.market_data_service.WssMarketDataService import on_orderbook_snapshot_or_update
from pyokx.okx_market_maker.market_data_service.model.OrderBook import OrderBook
from pyokx.rest_handling import get_ticker_with_higher_volume
from pyokx.ws_clients.WsPprivateAsync import WsPrivateAsync
from pyokx.ws_clients.WsPublicAsync import WsPublicAsync
from pyokx.ws_data_structures import PriceLimitChannel, InstrumentsChannel, \
    MarkPriceChannel, IndexTickersChannel, MarkPriceCandleSticksChannel, IndexCandleSticksChannel, AccountChannel, \
    PositionsChannel, BalanceAndPositionsChannel, WebSocketConnectionConfig, OrdersChannel, OrderBookChannel, \
    TickersChannel, IndexTickersChannelInputArgs, OrderBookInputArgs, MarkPriceChannelInputArgs, \
    TickersChannelInputArgs, OrdersChannelInputArgs, OKX_WEBSOCKET_URLS, public_channels_available, \
    business_channels_available, private_channels_available, available_channel_models
from redis_tools.utils import serialize_for_redis, get_async_redis
from shared.logging import setup_logger

logger = setup_logger(__name__)

# Passed as ``maxlen`` to every Redis XADD issued by this module; overridable through the environment.
REDIS_STREAM_MAX_LEN = int(os.getenv('REDIS_STREAM_MAX_LEN', 1000))

# Module-level Redis state. okx_websockets_main_run() assigns these (via ``global``) and ws_callback() reads them.
# Defining the defaults here means the callback never fails with NameError before, or without, Redis being configured.
async_redis = None
_redis_store = False


async def ws_callback(message):
    """
    Callback function to handle messages received from the websocket.

    The raw message is decoded from JSON and logged. Event messages (those carrying an ``"event"`` key) are only
    logged. Channel data is converted into the pydantic model registered for its channel in
    ``available_channel_models``. When Redis storage was enabled by ``okx_websockets_main_run``, the data is appended
    to the ``okx:websockets@<channel>`` and ``okx:websockets@all`` streams. ``index-tickers`` data is additionally
    written to ``okx:websockets@index-tickers@<instId>``. Failures are logged, not raised.

    Args:
        message (str): The raw message received from the websocket.
    """
    try:
        message_json = json.loads(message)
    except (TypeError, ValueError) as e:
        # Frames that are not JSON cannot be routed to a channel model; log and drop them instead of raising
        # into the caller.
        logger.warning(f"Exception: {e} \n {message}")
        return
    if not isinstance(message_json, dict):
        logger.warning(f"Unexpected WS message format: {message_json}")
        return

    event = message_json.get("event", None)
    logger.info(f"Incoming Raw WS Message: {message_json}")
    if event:
        if event == "error":
            logger.error(f"Error: {message_json}")
            return  # TODO: Handle events, mostly subscribe and error stop and reconnect task
        logger.info(f"Event: {message_json}")
        return  # TODO: Handle events, mostly subscribe and error

    # These globals are assigned by okx_websockets_main_run(). Look them up defensively so the callback does not
    # raise NameError when Redis storage is disabled or not configured yet.
    redis_client = globals().get("async_redis")
    store_enabled = globals().get("_redis_store", False)

    try:
        # Send out the supported models to their message stream channel
        # (verifies the channel message has a pydantic model).
        message_args = message_json.get("arg") or {}
        message_channel = message_args.get("channel")
        data_struct = available_channel_models[message_channel]
        if hasattr(data_struct, "from_array"):  # only applicable to a few scenarios with candlesticks
            structured_message = data_struct.from_array(**message_json)
        else:
            structured_message = data_struct(**message_json)
        logger.info(f"Received Structured-Data: {structured_message}")

        if store_enabled and redis_client:
            redis_ready_message = serialize_for_redis(structured_message)
            await redis_client.xadd(f'okx:websockets@{message_channel}', {'data': redis_ready_message},
                                    maxlen=REDIS_STREAM_MAX_LEN)
            await redis_client.xadd(f'okx:websockets@all', {'data': redis_ready_message},
                                    maxlen=REDIS_STREAM_MAX_LEN)
            # (ALPHA) Per-instrument fan-out for index tickers. It is kept inside the Redis guard so it is skipped,
            # rather than failing on a missing client, when storage is disabled.
            if message_channel == "index-tickers":
                await redis_client.xadd(f'okx:websockets@{message_channel}@{message_args.get("instId")}',
                                        {'data': redis_ready_message}, maxlen=REDIS_STREAM_MAX_LEN)

        # (ALPHA) Supported channels could also be handled by listening to the redistributed Redis streams above, e.g.
        #   message = deserialize_from_redis(r.xrevrange('okx:websockets@account', count=1)[0][1])
        #   account: Account = on_account(incoming_account_message)
        #   redis_ready_message = serialize_for_redis(account.to_dict())
        #   r.xadd(f'okx:reports@{message.get("arg").get("channel")}', redis_ready_message, maxlen=1000)
        # await handle_reports(message_json, redis_client)  # TODO: Handle reports
    except Exception as e:
        logger.warning(f"Exception: {e} \n {message_json}")
        return  # TODO: Handle exceptions
async def _shutdown_ws_client(client, params):
    """Unsubscribe ``client`` from ``params`` (if any) and stop it; errors are logged so other cleanup still runs."""
    if not client:
        return
    try:
        if params:
            await client.unsubscribe(params=params)
        await client.stop()
    except Exception as e:
        logger.error(f"Exception: {e}")


async def okx_websockets_main_run(input_channel_models: list, apikey: str = None, passphrase: str = None,
                                  secretkey: str = None, sandbox_mode: bool = True, redis_store: bool = True):
    """
    Initializes and manages websocket connections to the OKX exchange.

    The requested channel inputs are routed to the public, business or private endpoint. The ``*_demo`` URLs are used
    when ``sandbox_mode`` is True. Each non-empty group is subscribed, and every incoming message is handled by
    ``ws_callback``. The coroutine then loops forever. Every 60 seconds it logs a heartbeat and, when Redis storage is
    enabled, prints the Redis INFO fields whose names contain "human". On exit, including when startup fails midway,
    every started client is unsubscribed and stopped and the Redis connection is closed.

    Args:
        input_channel_models (list): A list of channel input arguments specifying the channels to subscribe.
        apikey (str, optional): The API key for accessing private channels.
        passphrase (str, optional): The passphrase associated with the API key.
        secretkey (str, optional): The secret key for accessing private channels.
        sandbox_mode (bool, optional): If True, use the sandbox environment. Defaults to True.
        redis_store (bool, optional): If True, store the received messages in a Redis stream. Defaults to True.

    Raises:
        Exception: If no channels are provided or if an invalid channel is specified.
        AssertionError: If private channels are requested without an API key, secret key and passphrase.
    """
    # ws_callback() reads these module-level names.
    global async_redis
    global _redis_store

    logger.info(f"Starting OKX Websocket Connections ")
    if not input_channel_models:
        raise Exception("No channels provided")

    public_channels_config = WebSocketConnectionConfig(
        name='okx_public',
        wss_url=OKX_WEBSOCKET_URLS["public"] if not sandbox_mode else OKX_WEBSOCKET_URLS["public_demo"],
        channels=public_channels_available
    )
    business_channels_config = WebSocketConnectionConfig(
        name='okx_business',
        wss_url=OKX_WEBSOCKET_URLS["business"] if not sandbox_mode else OKX_WEBSOCKET_URLS["business_demo"],
        channels=business_channels_available
    )
    private_channels_config = WebSocketConnectionConfig(
        name='okx_private',
        wss_url=OKX_WEBSOCKET_URLS["private"] if not sandbox_mode else OKX_WEBSOCKET_URLS["private_demo"],
        channels=private_channels_available
    )

    # Route (and validate) every requested channel before any connection is opened.
    public_channels_inputs = []
    business_channels_inputs = []
    private_channels_inputs = []
    for input_channel in input_channel_models:
        if input_channel.channel in public_channels_config.channels:
            public_channels_inputs.append(input_channel)
        elif input_channel.channel in business_channels_config.channels:
            business_channels_inputs.append(input_channel)
        elif input_channel.channel in private_channels_config.channels:
            private_channels_inputs.append(input_channel)
        else:
            raise Exception(f"Channel {input_channel.channel} not found in available channels")

    public_params = [channel.model_dump() for channel in public_channels_inputs]
    business_params = [channel.model_dump() for channel in business_channels_inputs]
    private_params = [channel.model_dump() for channel in private_channels_inputs]

    if private_params:
        # Checked up front so missing credentials do not leave public/business connections open.
        assert apikey, "API key was not provided"
        assert secretkey, "API secret key was not provided"
        assert passphrase, "Passphrase was not provided"

    # Always set both globals so ws_callback sees a consistent state, including when Redis is disabled.
    _redis_store = redis_store
    async_redis = await get_async_redis() if redis_store else None

    public_client = None
    business_client = None
    private_client = None
    try:
        if public_params:
            public_client = WsPublicAsync(url=public_channels_config.wss_url, callback=ws_callback)
            logger.info(f"Subscribing to public channels: {public_params}")
            await public_client.start()
            await public_client.subscribe(params=public_params)

        if business_params:
            business_client = WsPublicAsync(url=business_channels_config.wss_url, callback=ws_callback)
            logger.info(f"Subscribing to business channels: {business_params}")
            await business_client.start()
            await business_client.subscribe(params=business_params)

        if private_params:
            private_client = WsPrivateAsync(apikey=apikey, passphrase=passphrase, secretkey=secretkey,
                                            url=private_channels_config.wss_url, use_servertime=False,
                                            callback=ws_callback)
            logger.info(f"Subscribing to private channels: {private_params}")
            await private_client.start()
            await private_client.subscribe(params=private_params)

        # Keep the loop running, or perform other tasks. Startup errors above propagate to the caller (after the
        # cleanup in ``finally``); errors inside the loop are logged.
        try:
            while True:
                # This could be the main loop of the trading strategy or at least for the health checks
                await asyncio.sleep(60)
                logger.debug("Heartbeat \n ___________")
                if redis_store and async_redis:
                    logger.debug(f"Redis Stats: ")
                    # print only the relevant stats that have human in them
                    redis_stats = await async_redis.info()
                    stats_to_print = {k: v for k, v in redis_stats.items() if "human" in k}
                    pprint(stats_to_print)
                logger.debug("___________ \n Heartbeat")
        except KeyboardInterrupt:
            pass
        except Exception as e:
            logger.error(f"Exception: {e}")
    finally:
        await _shutdown_ws_client(public_client, public_params)
        await _shutdown_ws_client(business_client, business_params)
        await _shutdown_ws_client(private_client, private_params)
        if redis_store and async_redis:
            try:
                await async_redis.close()
            except Exception as e:
                logger.error(f"Exception: {e}")
        logger.info("Exiting")
# Order-book style channels whose messages are parsed by on_orderbook_snapshot_or_update.
_ORDER_BOOK_CHANNELS = ("books5", "books", "bbo-tbt", "books50-l2-tbt", "books-l2-tbt")


async def _xadd_report(async_redis, stream_name, report):
    """Serialize ``report.to_dict()`` and append it to the ``stream_name`` Redis stream (maxlen REDIS_STREAM_MAX_LEN)."""
    redis_ready_message = serialize_for_redis(report.to_dict())
    await async_redis.xadd(stream_name, redis_ready_message, maxlen=REDIS_STREAM_MAX_LEN)


async def handle_reports(message_json, async_redis):
    """
    Convert a raw channel message into the matching market-maker report object and publish it to Redis.

    Target streams:
        okx:reports@balance_and_position, okx:reports@account, okx:reports@positions, okx:reports@orders,
        okx:reports@<order-book channel>@<instId>, okx:reports@mark-price@<instId>, okx:reports@tickers@<instId>.
    Messages from any other channel, or without an ``arg`` section, are ignored.

    Args:
        message_json (dict): The decoded websocket message; ``arg.channel`` selects the handler.
        async_redis: The async Redis client used for ``xadd``.
    """
    message_args = message_json.get("arg") or {}
    message_channel = message_args.get("channel")
    inst_id = message_args.get("instId")

    # Service modules are imported lazily so they are only loaded when their channel is actually handled.
    # Position Management Service
    if message_channel == "balance_and_position":
        from pyokx.okx_market_maker.position_management_service.WssPositionManagementService import \
            on_balance_and_position
        await _xadd_report(async_redis, 'okx:reports@balance_and_position', on_balance_and_position(message_json))
    elif message_channel == "account":
        from pyokx.okx_market_maker.position_management_service.WssPositionManagementService import on_account
        await _xadd_report(async_redis, 'okx:reports@account', on_account(message_json))
    elif message_channel == "positions":
        from pyokx.okx_market_maker.position_management_service.WssPositionManagementService import on_position
        await _xadd_report(async_redis, 'okx:reports@positions', on_position(message_json))
    # Order Management Service
    elif message_channel == "orders":
        from pyokx.okx_market_maker.order_management_service.WssOrderManagementService import on_orders_update
        await _xadd_report(async_redis, 'okx:reports@orders', on_orders_update(message_json))
    # Market Data Service
    elif message_channel in _ORDER_BOOK_CHANNELS:
        await _xadd_report(async_redis, f'okx:reports@{message_channel}@{inst_id}',
                           on_orderbook_snapshot_or_update(message_json))
    elif message_channel == "mark-price":
        from pyokx.okx_market_maker.market_data_service.WssMarketDataService import on_mark_price_update
        await _xadd_report(async_redis, f'okx:reports@mark-price@{inst_id}', on_mark_price_update(message_json))
    elif message_channel == "tickers":
        from pyokx.okx_market_maker.market_data_service.WssMarketDataService import on_ticker_update
        # Keyed by channel + instId like the other market-data streams. Interpolating the Tickers object itself
        # would produce a repr-based, ever-changing stream name.
        await _xadd_report(async_redis, f'okx:reports@tickers@{inst_id}', on_ticker_update(message_json))
def get_instrument_specific_channel_inputs_to_listen_to(underlyings=("BTC-USDT", "ETH-USDT", "LTC-USDT"),
                                                        instrument_type="FUTURES", top_n=1):
    """
    Build per-instrument subscription inputs for the instruments returned by ``get_ticker_with_higher_volume``.

    For each underlying, ``get_ticker_with_higher_volume(underlying, instrument_type=..., top_n=...)`` is queried
    (presumably the ``top_n`` highest-volume contracts; see that helper). For every returned instrument, inputs are
    created for the ``books5``, ``books``, ``bbo-tbt``, ``mark-price`` and ``tickers`` channels.

    Args:
        underlyings (Iterable[str]): Instrument families to query. Defaults to BTC-USDT, ETH-USDT and LTC-USDT.
        instrument_type (str): Instrument type passed to the lookup. Defaults to "FUTURES".
        top_n (int): Number of instruments requested per underlying. Defaults to 1.

    Returns:
        list: Channel input-argument models, five per selected instrument.
    """
    instruments_to_listen_to = []
    for underlying in underlyings:
        instruments_to_listen_to.extend(
            get_ticker_with_higher_volume(underlying, instrument_type=instrument_type, top_n=top_n)
        )

    instrument_specific_channels = []
    for instrument in instruments_to_listen_to:
        inst_id = instrument.instId
        instrument_specific_channels.extend([
            OrderBookInputArgs(channel="books5", instId=inst_id),
            OrderBookInputArgs(channel="books", instId=inst_id),
            OrderBookInputArgs(channel="bbo-tbt", instId=inst_id),
            MarkPriceChannelInputArgs(channel="mark-price", instId=inst_id),
            TickersChannelInputArgs(channel="tickers", instId=inst_id),
        ])
    return instrument_specific_channels
def get_btc_usdt_usd_index_channel_inputs_to_listen_to():
    """
    Return ``index-tickers`` subscription inputs for the BTC-USDT and BTC-USD indices, in that order.

    Returns:
        list: Two ``IndexTickersChannelInputArgs`` instances.
    """
    index_ids = ("BTC-USDT", "BTC-USD")
    return [IndexTickersChannelInputArgs(channel="index-tickers", instId=index_id) for index_id in index_ids]
[docs] async def test_restart(public_client, business_client, private_client): clients = [client for client in [public_client, business_client, private_client] if hasattr(client, "restart")] await asyncio.gather(*[client.restart() for client in clients if client])
def _env_flag(name, default):
    """
    Read a boolean flag from the environment.

    Returns ``default`` when ``name`` is unset. Otherwise returns True only for "1", "true", "yes" or "on"
    (case-insensitive, surrounding whitespace ignored). A bare ``os.getenv`` would hand back the raw string, and any
    non-empty string such as "False" is truthy.
    """
    raw_value = os.getenv(name)
    if raw_value is None:
        return default
    return raw_value.strip().lower() in ("1", "true", "yes", "on")


if __name__ == '__main__':
    # Main execution block: loads the environment variables, prepares the channel models for subscription,
    # and starts the websocket main run function.
    dotenv.load_dotenv(dotenv.find_dotenv())

    instrument_specific_channels = []
    btc_usdt_usd_index_channels = []
    # instrument_specific_channels = get_instrument_specific_channel_inputs_to_listen_to()
    # btc_usdt_usd_index_channels = get_btc_usdt_usd_index_channel_inputs_to_listen_to()

    input_channel_models = (
        [
            ### Business Channels
            # MarkPriceCandleSticksChannelInputArgs(channel="mark-price-candle1m", instId=instId),
            # IndexCandleSticksChannelInputArgs(channel="index-candle1m", instId=instFamily),
            ### Public Channels
            # InstrumentsChannelInputArgs(channel="instruments", instType="FUTURES"),  # todo handle data
            # PriceLimitChannelInputArgs(channel="price-limit", instId=instId),  # todo handle data
            # Index with USD, USDT, BTC, USDC as the quote currency, e.g. BTC-USDT, e.g. not BTC-USDT-240628:
            # IndexTickersChannelInputArgs(channel="index-tickers", instId="BTC-USDT"),
            # IndexTickersChannelInputArgs(channel="index-tickers", instId="BTC-USD"),
            ### Private Channels
            # AccountChannelInputArgs(channel="account", ccy=None),
            # PositionsChannelInputArgs(channel="positions", instType="ANY", instFamily=None, instId=None),
            # BalanceAndPositionsChannelInputArgs(channel="balance_and_position"),
            OrdersChannelInputArgs(channel="orders", instType="FUTURES", instFamily=None, instId=None)
        ]
        + instrument_specific_channels
        + btc_usdt_usd_index_channels
    )

    asyncio.run(okx_websockets_main_run(
        input_channel_models=input_channel_models,
        apikey=os.getenv('OKX_API_KEY'),
        passphrase=os.getenv('OKX_PASSPHRASE'),
        secretkey=os.getenv('OKX_SECRET_KEY'),
        # todo/fixme this can differ from the actual trading strategy that is
        #  listening to the streams thus watch out
        # Parsed as a boolean: the raw env string "False" would otherwise be truthy and force sandbox mode.
        sandbox_mode=_env_flag('OKX_SANDBOX_MODE', True),
        redis_store=True,
    ))