Sharding PostgreSQL with Citus and Golang

Atharv Bhadange
6 min read · Dec 12, 2023


In the world of data, scaling vertically has its limits. As the volume of information we manage keeps growing, sharding emerges as a powerful ally: it distributes the data load and unlocks seamless horizontal scalability. In the realm of microservices, especially when dealing with log data, sharding becomes an indispensable strategy. In this guide, we implement database sharding using the Citus extension for PostgreSQL, complemented by a Golang backend powered by the GoFiber framework.

Before diving into the implementation, ensure you have a solid grasp of:

  1. Golang and the GoFiber framework.
  2. PostgreSQL fundamentals.
  3. Docker container basics.

Step 1: Setting the Foundation with GoFiber

  1. Initializing the project:
go mod init github.com/<your-github>/log-ingestor
# folder structure
log-ingestor/
|-- controller/base.go
|-- service/base.go
|-- routes/base.go
|-- db/db.go
|-- model/log.go
|-- script/call_script.py
|-- go.mod
|-- go.sum
|-- main.go

2. Adding GoFiber package:

go get github.com/gofiber/fiber/v2

3. Initializing the Fiber app and setting up default routes:

// main.go
package main

import (
    "github.com/atharv-bhadange/log-ingestor/routes"
    "github.com/gofiber/fiber/v2"
)

func main() {
    app := fiber.New()
    routes.Setup(app)
    app.Listen(":3000")
}

Under routes/base.go, add the Setup function used in the above code

// routes/base.go
package routes

import (
    "github.com/atharv-bhadange/log-ingestor/controller"
    "github.com/gofiber/fiber/v2"
)

func Setup(app *fiber.App) {
    app.Get("/", controller.HealthCheck)
}

Add the HealthCheck function to controller/base.go (it uses the model.Response struct, which we define in Step 2)

// controller/base.go
package controller

import (
    "github.com/atharv-bhadange/log-ingestor/model"
    "github.com/gofiber/fiber/v2"
)

func HealthCheck(c *fiber.Ctx) error {
    return c.Status(200).JSON(model.Response{
        Status:  200,
        Message: "Server is up and running",
        Data:    nil,
    })
}

Build and run the server to check for any errors so far

go build -o main && ./main

Step 2: Integrating CitusData for Distributed Database Magic:

  1. Installing and configuring Citus: the citusdata/citus image ships with PostgreSQL and the Citus extension pre-installed, so a single container gives us our coordinator node:
docker run -d --name coordinator -p 5432:5432 -e POSTGRES_PASSWORD=your_password citusdata/citus:12.1

2. Connecting to the database using GORM as our ORM:

# add gorm package
go get -u gorm.io/gorm
go get gorm.io/driver/postgres

Add a function under db/db.go to connect to the database

// db/db.go
package db

import (
    "log"

    "gorm.io/driver/postgres"
    "gorm.io/gorm"
)

var (
    Db *gorm.DB
)

func Connect() error {
    // connect to the default "postgres" database created by the Citus image
    dsn := "host=localhost user=postgres password=your_password dbname=postgres port=5432 sslmode=disable"
    var err error
    Db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
    if err != nil {
        return err
    }
    log.Println("Database connected")
    return nil
}
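
Connect has to run before the server starts handling requests, so call it from main.go. A minimal sketch of the updated main.go, using the Connect function above:

// main.go
package main

import (
    "log"

    "github.com/atharv-bhadange/log-ingestor/db"
    "github.com/atharv-bhadange/log-ingestor/routes"
    "github.com/gofiber/fiber/v2"
)

func main() {
    // connect to Postgres (and, later, run migrations) before serving requests
    if err := db.Connect(); err != nil {
        log.Fatal(err)
    }

    app := fiber.New()
    routes.Setup(app)
    app.Listen(":3000")
}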

3. Defining the structure for log messages.

{
  "level": "error",
  "message": "Failed to connect to DB",
  "resourceId": "server-1234",
  "timestamp": "2023-09-15T08:00:00Z",
  "traceId": "abc-xyz-123",
  "spanId": "span-456",
  "commit": "5e5342f",
  "metadata": "{'source': 'client'}"
}

Write a corresponding struct for it under model/log.go

// model/log.go
package model

import "time"

type Log struct {
    ID         uint      `json:"id" gorm:"column:id"`
    Level      string    `json:"level" gorm:"column:level"`
    Message    string    `json:"message" gorm:"column:message"`
    ResourceId string    `json:"resourceId" gorm:"column:resourceId"`
    Timestamp  time.Time `json:"timestamp" gorm:"column:timestamp"`
    TraceId    string    `json:"traceId" gorm:"column:traceId"`
    SpanId     string    `json:"spanId" gorm:"column:spanId"`
    Commit     string    `json:"commit" gorm:"column:commit"`
    Metadata   string    `json:"metadata" gorm:"column:metadata"`
}

type Response struct {
    Status  int         `json:"status"`
    Message string      `json:"message"`
    Data    interface{} `json:"data"`
}

Note: ‘id’ is kept as a simple uint because it will later serve as the distribution (shard) column, and the Response struct gives every endpoint a consistent response shape.
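
GORM's default naming strategy pluralizes struct names, so the Log struct is created as a table called logs, which is the name used in the Citus SQL in Step 5. If you want to make that mapping explicit, a small optional addition to model/log.go:

// model/log.go (optional)
// Pin the table name so it always matches the "logs" table referenced later.
func (Log) TableName() string {
    return "logs"
}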

Step 3: Migrating the Database:

  1. Creating and migrating the Log table:

Add these lines to the Connect function, right after the connection is established (you will also need to import the model package in db/db.go):

err = Db.AutoMigrate(&model.Log{})
if err != nil {
    return err
}

log.Println("Database migrated")

Now, when you start the server, it connects to the database and migrates the Log table. We can verify this by connecting to the coordinator container and listing the tables present in it.

# to connect
docker exec -it coordinator psql -U postgres

# to check tables
\dt

Step 4: Crafting Endpoints and Testing:

  1. Implementing CRUD endpoints:
// routes/base.go
func Setup(app *fiber.App) {
    app.Get("/", controller.HealthCheck)

    app.Post("/log", controller.AddLog)
    app.Get("/log", controller.GetLog)
    app.Get("/log/:id", controller.GetLogById)
}

// controller/base.go
// (import "strconv" and the service package alongside fiber and model)

func AddLog(c *fiber.Ctx) error {
    var log model.Log

    if err := c.BodyParser(&log); err != nil {
        return c.Status(400).JSON(model.Response{
            Status:  400,
            Message: "Bad request",
            Data:    nil,
        })
    }

    log, err := service.AddLog(log)
    if err != nil {
        return c.Status(500).JSON(model.Response{
            Status:  500,
            Message: "Internal server error",
            Data:    nil,
        })
    }

    return c.Status(200).JSON(model.Response{
        Status:  200,
        Message: "Log added successfully",
        Data:    log.Message,
    })
}

func GetLog(c *fiber.Ctx) error {
    logs, err := service.GetLog()
    if err != nil {
        return c.Status(500).JSON(model.Response{
            Status:  500,
            Message: "Internal server error",
            Data:    nil,
        })
    }

    return c.Status(200).JSON(model.Response{
        Status:  200,
        Message: "Log fetched successfully",
        Data:    logs,
    })
}

func GetLogById(c *fiber.Ctx) error {
    idStr := c.Params("id")
    id, err := strconv.Atoi(idStr)
    if err != nil {
        return c.Status(400).JSON(model.Response{
            Status:  400,
            Message: "Bad request",
            Data:    nil,
        })
    }

    log, err := service.GetLogById(id)
    if err != nil {
        return c.Status(500).JSON(model.Response{
            Status:  500,
            Message: "Internal server error",
            Data:    nil,
        })
    }

    return c.Status(200).JSON(model.Response{
        Status:  200,
        Message: "Log fetched successfully",
        Data:    log,
    })
}

// service/base.go
package service

import (
    "github.com/atharv-bhadange/log-ingestor/db"
    "github.com/atharv-bhadange/log-ingestor/model"
)

func AddLog(log model.Log) (model.Log, error) {
    tx := db.Db.Create(&log)
    if tx.Error != nil {
        return model.Log{}, tx.Error
    }
    return log, nil
}

func GetLog() ([]model.Log, error) {
    var logs []model.Log
    tx := db.Db.Order("id").Find(&logs)
    if tx.Error != nil {
        return nil, tx.Error
    }
    return logs, nil
}

func GetLogById(id int) (model.Log, error) {
    var log model.Log
    tx := db.Db.Where("id = ?", id).First(&log)
    if tx.Error != nil {
        return model.Log{}, tx.Error
    }
    return log, nil
}

2. Testing the server and data manipulation:

Under script/call_script.py, I have added a Python script that calls the POST endpoint to insert dummy data. Start the server, run it with python3 script/call_script.py, and tune the number of requests and sleep time as you need.

import requests
import json
import time
from datetime import datetime, timedelta
import random

# Define variations for each field
levels = ["info", "warning", "error", "critical", "debug"]
messages = ["Connection established", "Failed to connect to DB", "Resource not found", "User not authenticated", "User not authorized", "Invalid request", "Invalid response", "Request timed out"]
resource_ids = ["server-1234", "server-5678", "server-91011", "server-121314", "server-151617", "server-181920"]
trace_ids = ["abc-xyz-123", "def-uvw-456", "ghi-rst-789", "jkl-mno-101112", "pqr-stu-131415", "vwx-efg-161718"]
span_ids = ["span-123", "span-456", "span-789", "span-101112", "span-131415", "span-161718"]
commits = ["5e5342f", "a1b2c3d", "e4f5g6h", "i7j8k9l", "m0n1o2p", "q3r4s5t"]
metadata_values = ["{'key': 'value'}", "{'type': 'log'}", "{'source': 'app'}", "{'source': 'server'}", "{'source': 'client'}", "{'source': 'database'}"]

# Generate timestamps for the past 6 months
end_time = datetime.utcnow()
start_time = end_time - timedelta(days=180)

# Function to generate a random timestamp within the specified range
def generate_random_timestamp():
    return start_time + timedelta(seconds=random.randint(0, int((end_time - start_time).total_seconds())))

# API endpoint
api_url = "http://localhost:3000/log"

# Number of requests
num_requests = 500

for _ in range(num_requests):
    data = {
        "level": random.choice(levels),
        "message": random.choice(messages),
        "resourceId": random.choice(resource_ids),
        "timestamp": generate_random_timestamp().strftime("%Y-%m-%dT%H:%M:%SZ"),
        "traceId": random.choice(trace_ids),
        "spanId": random.choice(span_ids),
        "commit": random.choice(commits),
        "metadata": random.choice(metadata_values)
    }

    # Convert data to JSON
    json_data = json.dumps(data)

    # Make a POST request to the API
    response = requests.post(api_url, data=json_data, headers={"Content-Type": "application/json"})

    # Print the response status code and sleep for 10ms
    print(f"Status Code: {response.status_code}")
    time.sleep(0.01)

Step 5: Distributing Data Across Multiple Shards:

In the Postgres command line (the psql session from Step 3), execute

-- Convert the table to a distributed table, sharded by the 'id' column
SELECT create_distributed_table('logs', 'id');

This will automatically distribute the existing table’s data into multiple shards. Next, create an index over the id column; Citus propagates it to each shard, giving us a distributed index:

CREATE INDEX log_id ON logs(id);
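
If you would rather keep this schema step inside the application instead of psql, the same statements can be run through GORM. A minimal sketch, assuming db.Connect() has already been called (the Distribute helper below is illustrative, not part of the repo):

// db/distribute.go (hypothetical helper)
package db

// Distribute turns the logs table into a Citus distributed table,
// sharded on the id column, and creates the index on every shard.
func Distribute() error {
    if err := Db.Exec("SELECT create_distributed_table('logs', 'id')").Error; err != nil {
        return err
    }
    return Db.Exec("CREATE INDEX IF NOT EXISTS log_id ON logs(id)").Error
}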

To check the actual shard metadata, run

SELECT * FROM citus_shards;

You should see 32 shards created for your table (32 is the default Citus shard count)

To check the placement of shards over the nodes:

SELECT * FROM pg_dist_placement;

As we are running on a single node, all the shards are placed in groupid 0. Note that queries filtering on the distribution column, like GET /log/:id, are routed to a single shard, while queries without it, like GET /log, fan out across all shards.

Step 6: Scaling Beyond a Single Node:

As we wrap up this chapter, anticipate the next instalment where we venture into setting up a multi-node cluster using AWS EC2 instances. Stay tuned for more insights, and if you enjoyed this journey, leave a 👏 to show your appreciation!

Find full implementation here: https://github.com/atharv-bhadange/log-ingestor

PS: Your corrections/improvements are welcome!
