Applied RL: Customization of RL policies using StableBaselines3

Akhilesh Gogikar
4 min read · Jun 6, 2022


As explained in the previous post in this series, we developed a multi-stock trading environment using OpenAI’s Gym API.

We can train an RL model on this environment by importing the appropriate algorithm from Stable Baselines3, the community-maintained successor to Stable Baselines, which itself started as a fork of OpenAI’s Baselines library.

from multi_stock_trading_env import MultiStockTradingEnv
from stable_baselines3 import PPO

env = MultiStockTradingEnv(df_list, price_df, num_stocks=num_assets,
                           initial_amount=1000000, trade_cost=0,
                           num_features=cols_per_asset, window_size=12,
                           frame_bound=(12, len(price_df) - 1500),
                           tech_indicator_list=indicators)
prices, features = env.process_data()

model = PPO('MlpPolicy', env, verbose=2, tensorboard_log='tb_logs', batch_size=256)
model.learn(total_timesteps=1000)

We can customize the RL algorithm further by writing our own policy class and using it to train the agent.

Stable Baselines3 ships with an out-of-the-box features extractor, and we are free to customize the fully-connected network that sits on top of it.
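If all we want is to resize those fully-connected layers, we do not even need a custom policy class; a minimal sketch via policy_kwargs (assuming SB3 1.x, where net_arch for PPO takes a list containing a dict of pi/vf layer sizes):

from stable_baselines3 import PPO

# The default MlpPolicy flattens the observation and then applies the
# fully-connected layers described by net_arch; here we only resize them.
policy_kwargs = dict(net_arch=[dict(pi=[128, 64], vf=[128, 64])])
model = PPO('MlpPolicy', env, policy_kwargs=policy_kwargs,
            verbose=2, tensorboard_log='tb_logs', batch_size=256)

For anything more structural than layer sizes, we write the policy ourselves.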

We import the base ActorCriticPolicy class from stable_baselines3, override its _build_mlp_extractor method so that it installs our own mlp_extractor, and then write the custom network that this method instantiates.

from typing import Callable, Dict, List, Optional, Tuple, Type, Union

import gym
import torch as th
from torch import nn
from stable_baselines3.common.policies import ActorCriticPolicy


class CustomActorCriticPolicy(ActorCriticPolicy):
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Callable[[float], float],
        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        *args,
        **kwargs,
    ):
        super(CustomActorCriticPolicy, self).__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            # Pass remaining arguments to the base class
            *args,
            **kwargs,
        )
        # Disable orthogonal initialization
        self.ortho_init = False

    def _build_mlp_extractor(self) -> None:
        self.mlp_extractor = CustomNetwork(
            last_layer_dim_pi=self.action_space.shape[0],
            last_layer_dim_vf=self.action_space.shape[0],
            timesteps=self.observation_space.shape[1],
            feature_dim=self.observation_space.shape[2],
        )

For the MLP extractor we will build a simple custom architecture of our own, one that reduces the input along a specific dimension at each step.

Keep in mind that the default features extractor in Stable Baselines3 is a flattening operation, so our observation of shape (Num_stocks, Timesteps, Num_Features) reaches the extractor as a 1D array of size Num_stocks*Timesteps*Num_Features.
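As a quick sanity check (with hypothetical sizes of 5 stocks, 12 timesteps and 20 features per asset), the flattening looks like this:

import torch as th

num_stocks, timesteps, num_features = 5, 12, 20   # hypothetical sizes
obs = th.zeros(num_stocks, timesteps, num_features)

# The default FlattenExtractor collapses every dimension into one vector
flat = th.flatten(obs)
print(flat.shape)   # torch.Size([1200]) == 5 * 12 * 20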

Our neural network transforms this flattened input into an output of size Num_stocks, one value per asset.

(Num_stocks*Timesteps*Num_Features) -> nn.Linear -> (Num_stocks*Num_Features) -> ReLU() -> nn.Linear -> (Num_stocks,) -> Tanh()

class CustomNetwork(nn.Module):
    """
    Custom network for the policy and value function.
    It receives as input the features extracted by the features extractor.

    :param feature_dim: dimension of the features extracted by the
        features_extractor (e.g. features from a CNN)
    :param last_layer_dim_pi: (int) number of units for the last layer of the policy network
    :param last_layer_dim_vf: (int) number of units for the last layer of the value network
    """

    def __init__(
        self,
        feature_dim: int,
        timesteps: int = 12,
        last_layer_dim_pi: int = 64,
        last_layer_dim_vf: int = 64,
    ):
        super(CustomNetwork, self).__init__()

        # IMPORTANT:
        # Save output dimensions, used to create the distributions
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        # Policy network:
        # (Num_stocks*Timesteps*Num_Features) -> (Num_stocks*Num_Features) -> (Num_stocks,)
        self.policy_net = nn.Sequential(
            nn.Linear(last_layer_dim_pi * timesteps * feature_dim,
                      last_layer_dim_pi * feature_dim),
            nn.ReLU(),
            nn.Linear(last_layer_dim_pi * feature_dim, last_layer_dim_pi),
            nn.Tanh(),
        )
        # Value network: same reduction, sized by last_layer_dim_vf
        self.value_net = nn.Sequential(
            nn.Linear(last_layer_dim_vf * timesteps * feature_dim,
                      last_layer_dim_vf * feature_dim),
            nn.ReLU(),
            nn.Linear(last_layer_dim_vf * feature_dim, last_layer_dim_vf),
            nn.Tanh(),
        )

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified
            network. If all layers are shared, then ``latent_policy == latent_value``
        """
        return self.policy_net(features), self.value_net(features)

    def forward_actor(self, features: th.Tensor) -> th.Tensor:
        return self.policy_net(features)

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        return self.value_net(features)

Let’s see how the model performs by training it on part of the data set and then running inference over the held-out test data.

model = PPO(CustomActorCriticPolicy, env, verbose=2, tensorboard_log='tb_logs', batch_size=256)
model.learn(total_timesteps=1000)
(Figure: rewards of the RL model with the new custom MLP architecture.)
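The inference loop itself is not shown above, so here is a minimal sketch; test_env and its frame_bound are assumptions, built over the tail of price_df that the training environment left out:

# Hypothetical test environment over the held-out tail of the data
test_env = MultiStockTradingEnv(df_list, price_df, num_stocks=num_assets,
                                initial_amount=1000000, trade_cost=0,
                                num_features=cols_per_asset, window_size=12,
                                frame_bound=(len(price_df) - 1500, len(price_df) - 1),
                                tech_indicator_list=indicators)
prices, features = test_env.process_data()

obs = test_env.reset()
done = False
total_reward = 0.0
while not done:
    # deterministic=True takes the mode of the action distribution
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, done, info = test_env.step(action)
    total_reward += reward
print('Total test reward:', total_reward)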

That’s it: we have now designed custom policy and value networks for our RL agent.

In fact, as long as the network takes the flattened 1D input of size Num_stocks*Timesteps*Num_Features and outputs a vector of size (Num_stocks,), any neural network can serve as the policy or value network, provided the learning process stays stable and bug-free.
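To illustrate that point, here is a hypothetical deeper head (not used in the repo) with the same input and output sizes, which could be dropped into CustomNetwork in place of policy_net or value_net:

import torch.nn as nn

def make_deeper_head(num_stocks: int, timesteps: int, num_features: int,
                     hidden: int = 256) -> nn.Sequential:
    # Same interface as CustomNetwork's heads: flattened observation in,
    # one value per stock out.
    return nn.Sequential(
        nn.Linear(num_stocks * timesteps * num_features, hidden),
        nn.ReLU(),
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Linear(hidden, num_stocks),
        nn.Tanh(),
    )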

In the next article, we will get a bit more adventurous: adding recurrent layers to learn along the temporal dimension, putting an attention mechanism in the hidden layers, and trying out geometric learning by building a fully connected graph with each node representing an asset.

You can find the code for this article on GitHub.

Disclaimer: these articles are not to be construed as investment advice. Most algorithmic trading systems lose money when deployed to production. This article series is for educational purposes only.

I would appreciate it if you could show some love and leave a star on the repo.

The next article in this series can be found here →

