Adding compile-time safety to the AWS SDK with syn’s Visit trait

Sam Van Overmeire
6 min read · Jul 2, 2024


Image by Bing Image Creator (prompt: “The Rust crab checking AWS SDK calls”)

Earlier this year, I wrote a blog post on how to use syn’s Fold trait to recursively transform a function, as an appendix of sorts to my book. Another trait I still wanted to write about is Visit. As it happens, some code I wrote recently turned out to be an excellent use case for it.

Let’s take a step back. Open source code is often born from frustration. A developer finds that there is no good or easy way to do something, and decides to take action. The origins of my tiny project are much the same. I was using the AWS Rust SDK to call DynamoDB and S3, and I was annoyed to see my Lambda crash — twice — with an error, forcing me to dive into the logs to find out what went wrong. In both cases, the root cause was my failure to pass on all the required properties to the client builder. If anything is missing from those builders, the AWS SDK errors at runtime.

That was frustrating because, most of the time, Rust code that compiles also works. Plus, the language offers some really great tooling for writing safe builders, like the typestate pattern, a subject I briefly explore in my book. Now, to be fair to AWS, its SDKs are generated from specifications using a tool called Smithy, and with that extra layer of complexity, it’s probably harder to generate this type of safety. Even so, I was hoping for failure at compile time. That would give me a much shorter feedback loop than waiting for my Lambdas to fail at runtime.
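To make the typestate idea concrete, here is a minimal sketch of a builder that refuses to compile when a required property is missing. It is not how the AWS SDK works; `SendMessageBuilder`, `NoUrl`, `Url` and `SendMessage` are hypothetical names invented for this illustration.

```rust
// Hypothetical request type, loosely modelled on an SQS send-message call.
#[derive(Debug, PartialEq)]
pub struct SendMessage {
    pub queue_url: String,
    pub message_body: String,
}

// Type-level states: the queue URL is either absent or present.
pub struct NoUrl;
pub struct Url(String);

// The builder is generic over the state of its `queue_url` field.
pub struct SendMessageBuilder<U> {
    queue_url: U,
    message_body: String,
}

impl SendMessageBuilder<NoUrl> {
    // Every builder starts out without a queue URL.
    pub fn new(message_body: &str) -> Self {
        SendMessageBuilder {
            queue_url: NoUrl,
            message_body: message_body.to_string(),
        }
    }

    // Supplying the URL moves the builder into the `Url` state.
    pub fn queue_url(self, url: &str) -> SendMessageBuilder<Url> {
        SendMessageBuilder {
            queue_url: Url(url.to_string()),
            message_body: self.message_body,
        }
    }
}

impl SendMessageBuilder<Url> {
    // `build` only exists once the URL has been set, so forgetting it
    // is a compile error rather than a runtime failure.
    pub fn build(self) -> SendMessage {
        SendMessage {
            queue_url: self.queue_url.0,
            message_body: self.message_body,
        }
    }
}
```

With this design, `SendMessageBuilder::new("some message").build()` does not compile, because `build` is not defined for the `NoUrl` state. Only the chain that goes through `queue_url(...)` reaches a type that has it.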

So, I wrote an attribute macro to do just that, and called it required_props. If you annotate a function or method with it, Rust will fail at compile time when it sees an SDK call that is missing a required property.

In other words, this code (with the queue URL missing):

use aws_sdk_compile_checks_macro::required_props;
use aws_sdk_sqs::config::BehaviorVersion;

#[required_props]
async fn example() -> Result<(), String> {
    let aws_config = aws_config::load_defaults(BehaviorVersion::latest()).await;
    let sqs_client = aws_sdk_sqs::Client::new(&aws_config);
    let _ = sqs_client.send_message()
        // missing queue url
        .message_body("some message")
        .send()
        .await;
    Ok(())
}

will produce this compile error:

error: method `send_message` (from sqs) is missing required argument(s): `queue_url`
   --> ...
    |
191 |     let _ = sqs_client.send_message()
    |                        ^^^^^^^^^^^^
...

A large part of the repo, some 700 lines of code excluding tests, is dedicated to heuristics that help us determine the required properties. How do we know that this is an (AWS) SDK call? And if it is an SDK call, what specific client is used, and what method is called? These heuristics are not perfect, but in my experiments they work well enough to add value, an additional layer of safety. And if/when the macro is wrong, the developer can simply remove it, or adapt its attributes to make the heuristics more accurate (as described in the project’s README).
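Once the heuristics have settled on an SDK and an operation, the check itself is conceptually small. The sketch below is a simplification under stated assumptions, not the project's code: the function names are hypothetical, a hard-coded `match` stands in for the real map of required properties, and the chained builder calls are passed in as plain strings.

```rust
// Hypothetical lookup of required properties. A hard-coded `match` stands in
// for the generated map the real project uses.
fn required_props_for(sdk: &str, operation: &str) -> &'static [&'static str] {
    match (sdk, operation) {
        ("sqs", "send_message") => &["queue_url", "message_body"],
        _ => &[],
    }
}

// Returns every required property that does not show up among the methods
// chained after the operation call.
pub fn missing_required(sdk: &str, operation: &str, chained_calls: &[&str]) -> Vec<&'static str> {
    required_props_for(sdk, operation)
        .iter()
        .copied()
        .filter(|prop| !chained_calls.iter().any(|call| call == prop))
        .collect()
}

// Formats a message in the same shape as the compile error shown above,
// or returns `None` when nothing is missing.
pub fn error_message(sdk: &str, operation: &str, chained_calls: &[&str]) -> Option<String> {
    let missing = missing_required(sdk, operation, chained_calls);
    if missing.is_empty() {
        return None;
    }
    let list = missing
        .iter()
        .map(|m| format!("`{}`", m))
        .collect::<Vec<_>>()
        .join(", ");
    Some(format!(
        "method `{}` (from {}) is missing required argument(s): {}",
        operation, sdk, list
    ))
}
```

For the SQS example, calling `error_message("sqs", "send_message", &["message_body", "send"])` yields the message from the compiler output. The hard part, as the text says, is working out the `sdk` and `operation` arguments in the first place.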

Now heuristics are all fine and dandy, but to make them work, we need some way of analyzing the annotated functions and retrieving everything the macro needs to judge the call. Doing this for the function signature is relatively easy: we check the function arguments to see if an AWS client is passed in. Its name and type might reveal what SDK is being used. The function body is harder, even though the goal is simple: we want to gather up all ExprMethodCall expressions (e.g. s3_client.get_bucket_policy()) along with the names of the methods they call.

The big issue is that we have no idea where those calls will occur. Maybe we are using the client in an ‘if expression’, so best to explore those. And we also have to explore the other 38 — at the time of writing — expression variants in syn. And we have to do all of it recursively. Doable? Yes. But it requires a lot of code, and quite a bit of it is boilerplate.

From the introduction, you may have already gathered how we are going to tackle those issues. What we need is the recursive magic that Fold offered for the problem in my other blog post. Except this time there is no need for ownership, because we don’t want to manipulate existing code; we just want to check it for issues. So Visit, which takes a borrowed syntax tree, is good enough this time around. Visit offers more than a hundred methods with default implementations that do only one thing: explore the substructure of the node that was passed along as an input. Here, for example, is the code for exploring a function (ItemFn):

pub trait Visit<'ast> {
    fn visit_item_fn(&mut self, i: &'ast crate::ItemFn) {
        visit_item_fn(self, i); // the default implementation...
    }

    // other methods
}

pub fn visit_item_fn<'ast, V>(v: &mut V, node: &'ast crate::ItemFn)
where
    V: Visit<'ast> + ?Sized,
{
    // ... calls a function that visits the substructure of the function
    for it in &node.attrs {
        v.visit_attribute(it);
    }
    v.visit_visibility(&node.vis);
    v.visit_signature(&node.sig);
    v.visit_block(&*node.block);
}

Thanks to these default implementations, using the trait is fairly easy: you override the methods that you are interested in, retrieving what you need from the input. Finally, you call the default implementation. Simple but ingenious. But do remember to call that default! Say you’re interested in a function’s body and use visit_item_fn to retrieve it, while visit_signature is used for analyzing the function’s signature:

// requires syn's "visit" and "full" features
use syn::visit::{self, Visit};
use syn::{ItemFn, Signature};

impl<'ast> Visit<'ast> for MyVisitor {
    // override the default method impl because we're interested in the function body
    fn visit_item_fn(&mut self, i: &'ast ItemFn) {
        println!("Visiting item fn...");
        // we forgot to call the default function!
    }

    // override the default method impl because we're interested in the signature
    fn visit_signature(&mut self, i: &'ast Signature) {
        println!("Visiting signature...");
        // here we do remember to call the default
        visit::visit_signature(self, i);
    }
}

Calling this code with my_visitor.visit_item_fn(&item_fn); (where item_fn is the function parsed into a syn::ItemFn) will only print the first message (“Visiting item fn…”), not the second one. Visit never drilled down to the signature of the function, leaving part of the tree unexplored. All because we forgot to call the default.
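The same effect can be reproduced without syn. The sketch below uses a toy syntax tree (the `Expr` enum and both visitors are hypothetical, and far smaller than syn's real types) to show the pattern: trait methods default to free functions that walk the children, and an override that skips the default stops the walk at that node.

```rust
// A toy expression tree; not syn's real types.
pub enum Expr {
    Path(String),
    MethodCall { receiver: Box<Expr>, method: String, args: Vec<Expr> },
    If { cond: Box<Expr>, then_branch: Vec<Expr> },
}

pub trait Visit {
    // Each method defaults to the free function of the same name.
    fn visit_expr(&mut self, e: &Expr) {
        visit_expr(self, e);
    }
    fn visit_method_call(&mut self, receiver: &Expr, method: &str, args: &[Expr]) {
        visit_method_call(self, receiver, method, args);
    }
}

// Default behaviour: dispatch on the node and visit its children.
pub fn visit_expr<V: Visit + ?Sized>(v: &mut V, e: &Expr) {
    match e {
        Expr::Path(_) => {}
        Expr::MethodCall { receiver, method, args } => v.visit_method_call(receiver, method, args),
        Expr::If { cond, then_branch } => {
            v.visit_expr(cond);
            for inner in then_branch {
                v.visit_expr(inner);
            }
        }
    }
}

pub fn visit_method_call<V: Visit + ?Sized>(v: &mut V, receiver: &Expr, _method: &str, args: &[Expr]) {
    v.visit_expr(receiver);
    for arg in args {
        v.visit_expr(arg);
    }
}

// Records the method name, then keeps walking via the default.
#[derive(Default)]
pub struct Thorough {
    pub methods: Vec<String>,
}

impl Visit for Thorough {
    fn visit_method_call(&mut self, receiver: &Expr, method: &str, args: &[Expr]) {
        self.methods.push(method.to_string());
        visit_method_call(self, receiver, method, args);
    }
}

// Records the method name but forgets the default, so the receiver
// (and any call nested inside it) is never visited.
#[derive(Default)]
pub struct Forgetful {
    pub methods: Vec<String>,
}

impl Visit for Forgetful {
    fn visit_method_call(&mut self, _receiver: &Expr, method: &str, _args: &[Expr]) {
        self.methods.push(method.to_string());
    }
}

// Builds the tree for `if ready { client.send_message().send() }`.
pub fn sample() -> Expr {
    let send_message = Expr::MethodCall {
        receiver: Box::new(Expr::Path("client".to_string())),
        method: "send_message".to_string(),
        args: vec![],
    };
    let send = Expr::MethodCall {
        receiver: Box::new(send_message),
        method: "send".to_string(),
        args: vec![],
    };
    Expr::If {
        cond: Box::new(Expr::Path("ready".to_string())),
        then_branch: vec![send],
    }
}
```

Running both visitors over `sample()` shows the difference: `Thorough` collects `send` and then `send_message`, while `Forgetful` stops after `send`, because the inner call lives in the receiver that it never explores.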

In our case, to explore every single method call within the function, which was the goal we set a few paragraphs earlier, we should override visit_expr_method_call:

// requires syn's "visit" and "full" features
use syn::visit::{self, Visit};
use syn::{Expr, ExprMethodCall, Ident, Member};

impl<'ast> Visit<'ast> for MethodVisitor {
    fn visit_expr_method_call(&mut self, node: &'ast ExprMethodCall) {
        let method_call = node.method.clone(); // the easy bit: get the method call

        // but the receiver might also be interesting for heuristics, so try to capture it as well
        match node.receiver.as_ref() {
            Expr::Path(p) => {
                let segments = p.path.segments.clone();
                let receiver = segments.into_iter().map(|s| s.ident).collect::<Vec<Ident>>().pop();

                // push the method call and receiver onto our vector
                // (same thing happens in the matches below)
                self.method_calls.push(MethodCallWithReceiver { method_call, receiver });
            }
            Expr::Field(f) => {
                // call on a field, e.g. `a_struct.client` or `self.client`
                match &f.member {
                    Member::Named(field_name) => {
                        let receiver = Some(field_name.clone());
                        self.method_calls.push(MethodCallWithReceiver { method_call, receiver });
                    }
                    Member::Unnamed(_) => {
                        self.method_calls.push(MethodCallWithReceiver {
                            method_call,
                            receiver: None,
                        })
                    }
                }
            }
            _ => self.method_calls.push(MethodCallWithReceiver {
                method_call,
                receiver: None,
            }),
        }

        // keep walking: the receiver may itself contain method calls
        visit::visit_expr_method_call(self, node);
    }
}

As you can see, finding method calls like get_bucket_policy() is fairly easy, because those are contained within the method field. Meaning we can just clone and push them onto a vector. For our heuristics, we’re also interested in the ‘receiver’, i.e. the one calling these methods (e.g. s3_client for s3_client.get_bucket_policy()) because the name might reveal more about the client. For the very same reason, we visit local let bindings in case those tell us anything useful, by overriding the visit_local method (not shown here).

Now, we still have to invoke the visitor. But that’s not hard once we have the input as a function (syn::ItemFn):

#[proc_macro_attribute] // macro entrypoint
pub fn required_props(attrs: TokenStream, input: TokenStream) -> TokenStream {
    let attributes: Attributes = parse_macro_input!(attrs);
    let item: ItemFn = parse_macro_input!(input);
    // retrieve a map of all the required properties of the AWS SDK clients
    let required_props = create_required_props_map();

    // check the attributes

    // call the visitor with the input (a function as a token stream)
    let visitor = visitor::MethodVisitor::new(&item, required_props);
    // check for any improper calls
    let improper = visitor.find_improper_usages(sdks);

    // handle improper usages (if any) and return the input
}

With about 40 lines of code, we have elegantly gathered up every method call from the annotated function, with another 30 for the let bindings. The Visit trait turned parsing the function into (almost) a breeze, shifting the hard part to the ongoing work of figuring out good heuristics!

