[Shader]: k-Nearest Neighbors

Videep
Published in Analytics Vidhya
Oct 31, 2019

In this article, I would like to share a recent study I did on finding the nearest neighbors of a cell in a grid.

Problem Statement:

Given a 2D grid and an input grid-cell with coordinates X and Y, find the grid-cells nearest or adjacent to it. The search range can vary from 0 (the cell itself) to n cells.
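Stated a little more formally, and assuming that "within range" means at most k = Step cells away along each axis (the square window the shader code in this article loops over), the neighborhood of the input cell (X, Y) is:

$$
N_k(X, Y) = \{\, (i, j) \;:\; |i - X| \le k,\ |j - Y| \le k \,\}
$$

clipped to the grid. It includes the input cell itself and holds $(2k + 1)^2$ cells before clipping, for example 9 cells when $k = 1$.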

Input:

The solution needs the following inputs:

Grid Size -> X*Y or Rows*Columns

Grid-Cell -> The X and Y coordinates of the grid-cell (point) whose neighbors we need to find

Step or Range -> How many cells away from the input grid-cell we search for neighbors
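Bundled together, the three inputs could look like the C struct below. This is a hypothetical sketch, not taken from the shaders: NeighborQuery and query_is_valid are illustrative names, and the struct uses whole cell indices rather than the halved coordinates the shaders work with.

```c
#include <stdbool.h>

/* Hypothetical container for the three inputs listed above. */
typedef struct {
    int size_x;  /* grid size along X */
    int size_y;  /* grid size along Y */
    int x;       /* X coordinate of the input grid-cell */
    int y;       /* Y coordinate of the input grid-cell */
    int step;    /* search range in cells, from 0 to n */
} NeighborQuery;

/* A query only makes sense if the grid is non-empty, the input cell lies
   inside it and the range is not negative. */
bool query_is_valid(const NeighborQuery *q)
{
    return q->size_x > 0 && q->size_y > 0
        && q->x >= 0 && q->x < q->size_x
        && q->y >= 0 && q->y < q->size_y
        && q->step >= 0;
}
```

For example, {16, 16, 4, 8, 3} is a valid query, while an input cell at x = 16 on a 16-wide grid is not.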

Solution:

The solution is implemented as a shader in Unity (applied to a square plane).

It is also available as a ShaderToy image (640x360).

The Basics:

First, we create a 2D grid. I will be using ShaderToy for the web demonstration. The following snippet draws a checkerboard pattern to get you started:

const vec2 _GridSize = vec2(16.0, 16.0);

void mainImage( out vec4 fragColor, in vec2 fragCoord )
{
    // Normalized pixel coordinates (0 to 1), scaled by the grid size so that
    // uv runs from 0 to _GridSize and each integer step is one grid-cell
    vec2 uv = fragCoord / iResolution.xy * _GridSize;

    // Halve the cell index: the fractional part of the sum is then 0.0 or 0.5,
    // so checkerVal alternates between 0.0 (black) and 1.0 (white)
    vec2 floorVal = floor(uv) / 2.0;
    float checkerVal = fract(floorVal.x + floorVal.y) * 2.0;
    vec3 col = vec3(checkerVal);

    // Output to screen
    fragColor = vec4(col, 1.0);
}
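Why does this produce a checkerboard? Dividing the cell index by 2 makes floorVal.x + floorVal.y a multiple of 0.5, so its fractional part is 0.0 or 0.5 and the result is 0.0 or 1.0 depending on whether the column and row indices sum to an even or odd number. The C sketch below checks that equivalence on the CPU; checker_value and checker_parity are hypothetical names, and floor is replaced by an integer cast, which is only valid because the grid coordinates are never negative.

```c
/* floor() for the non-negative coordinates the grid produces. */
static float floor_nonneg(float v)
{
    return (float)(int)v;
}

/* CPU port of the checkerboard formula above, for (u, v) already scaled
   by the grid size: fract(floor(u) / 2 + floor(v) / 2) * 2. */
float checker_value(float u, float v)
{
    float s = floor_nonneg(u) / 2.0f + floor_nonneg(v) / 2.0f;
    return (s - floor_nonneg(s)) * 2.0f;  /* GLSL fract(s) = s - floor(s) */
}

/* The same pattern in integer arithmetic: 1 (white) when the column and
   row indices sum to an odd number, 0 (black) otherwise. */
int checker_parity(int col, int row)
{
    return (col + row) % 2;
}
```

Sampling the center of every cell of a 16x16 grid, both functions agree, which is why the shader needs no integer modulo.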

The Nearest Neighbors:

Once the 2D grid is in place, we can take the remaining inputs and start writing the algorithm.

A 2D grid is simply a collection of (number of rows * number of columns) cells. Hence any position on the grid has two coordinates, X and Y, which act as its row and column.
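Since the grid is just rows * columns cells, it can also be stored as a flat array. The C sketch below shows the usual row-major mapping between a (row, column) pair and a flat index; the helper names are hypothetical, and this layout is an illustration rather than something the shaders in this article use.

```c
/* Hypothetical helpers for a rows * cols grid stored row by row. */
int cell_index(int row, int col, int cols)
{
    return row * cols + col;
}

/* Inverse of cell_index: recover the row and column from a flat index. */
void cell_coords(int index, int cols, int *row, int *col)
{
    *row = index / cols;
    *col = index % cols;
}
```

For a 16-column grid, row 2, column 3 maps to index 2 * 16 + 3 = 35, and cell_coords(35, 16, ...) gives back row 2, column 3.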

Thinking in these terms makes computing the nearest neighbors much simpler: to get the adjacent grid-cells, all we need to do is change the row and column numbers. Changing them by one unit (1) gives the following coordinates:

StartPos : [X, Y]

Neighbors: [X-1, Y], [X+1, Y], [X, Y-1] and [X, Y+1]; including the diagonals adds [X+1, Y+1], [X-1, Y-1], [X-1, Y+1] and [X+1, Y-1], and larger ranges continue outward in the same way.
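The neighbor list above can be sketched in C as a table of offsets, in the same order as listed. grid_neighbors is a hypothetical helper, and skipping cells that fall outside the grid is an assumption about how edges should be handled.

```c
#include <stddef.h>

/* Offsets for the 4 edge neighbors followed by the 4 diagonals,
   in the order listed above. */
static const int DX[8] = { -1, 1,  0, 0, 1, -1, -1,  1 };
static const int DY[8] = {  0, 0, -1, 1, 1, -1,  1, -1 };

/* Writes the neighbors of (x, y) that lie inside a size_x * size_y grid
   into out_x/out_y and returns how many were written. Pass count = 4 for
   edge neighbors only, or 8 to include the diagonals. */
size_t grid_neighbors(int x, int y, int size_x, int size_y, int count,
                      int out_x[8], int out_y[8])
{
    size_t n = 0;
    for (int k = 0; k < count && k < 8; k++) {
        int nx = x + DX[k];
        int ny = y + DY[k];
        if (nx >= 0 && nx < size_x && ny >= 0 && ny < size_y) {
            out_x[n] = nx;
            out_y[n] = ny;
            n++;
        }
    }
    return n;
}
```

At the corner cell (0, 0) of a 16x16 grid only 2 of the 4 edge neighbors exist, or 3 of the 8 when diagonals are included.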

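For larger ranges, it makes more sense to use a pair of for loops that iterate from -Step to +Step around the input, covering every neighbor within the range. The following snippet checks whether the position being shaded is within range of the input and returns a weight for it: 1.0 at the input cell, falling to 0.5 at the corners of the range, and 0.0 outside it. The position passed in is the halved cell index floor(uv) / 2 computed for the checkerboard (the Unity shader below passes exactly this value), so adjacent cells are 0.5 apart; that is why the loops run from -0.5*Step to +0.5*Step in increments of 0.5.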

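// Example inputs (the values used in the examples below)
const vec2 _InputPos = vec2(2.0, 4.0);
const int _Step = 1;

// Returns a weight between 0.5 and 1.0 when pt lies within _Step cells of
// _InputPos, and 0.0 otherwise. pt is a halved cell index (floor(uv) / 2.0),
// so neighboring cells are 0.5 apart.
float isNearest(vec2 pt)
{
    float range = 0.5 * float(_Step);
    for (float x = -range; x <= range; x += 0.5)
    {
        for (float y = -range; y <= range; y += 0.5)
        {
            if (pt.x == _InputPos.x + x && pt.y == _InputPos.y + y)
            {
                // With _Step == 0 the formula below would be 0/0
                if (_Step == 0)
                    return 1.0;
                // 1.0 at the input cell, down to 0.5 at the corners of the range
                return float(_Step) / (float(_Step) + abs(x) + abs(y));
            }
        }
    }
    return 0.0;
}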
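To check the weights by hand, here is a CPU-side C port of isNearest that mirrors the GLSL function line for line. is_nearest and its parameter names are my own: (pos_x, pos_y) stands in for _InputPos and step for _Step, and pt is still a halved cell index.

```c
/* Absolute value without pulling in <math.h>. */
static float absf(float v)
{
    return v < 0.0f ? -v : v;
}

/* C port of isNearest(): weight of the halved cell index (pt_x, pt_y)
   relative to the input position (pos_x, pos_y) for a given step. */
float is_nearest(float pt_x, float pt_y, float pos_x, float pos_y, int step)
{
    float range = 0.5f * (float)step;
    for (float x = -range; x <= range; x += 0.5f) {
        for (float y = -range; y <= range; y += 0.5f) {
            if (pt_x == pos_x + x && pt_y == pos_y + y) {
                if (step == 0)
                    return 1.0f;  /* the general formula would be 0/0 */
                return (float)step / ((float)step + absf(x) + absf(y));
            }
        }
    }
    return 0.0f;
}
```

With Step = 1 the input cell gets 1.0, its edge neighbors 1 / 1.5 ≈ 0.67 and its diagonal neighbors 0.5. All offsets are multiples of 0.5, which are exact in binary floating point, so the == comparisons are safe as long as the input position is also a multiple of 0.5.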

Examples:

Grid : 16x16; _InputPos : [2.0,4.0]; Step : 0

Grid : 16x16; _InputPos : [2.0,4.0]; Step : 1

Grid : 16x16; _InputPos : [2.0,4.0]; Step : 2

Grid : 16x16; _InputPos : [2.0,4.0]; Step : 3
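To see how many cells each of these examples should highlight, here is a small C sketch that counts the cells inside the window, clipped to the grid. It assumes, as in the Unity shader below, that the position compared in isNearest is floor(uv) / 2, so _InputPos [2.0,4.0] is whole cell (4, 8) on the 16x16 grid; count_highlighted is a hypothetical helper, not part of the original shaders.

```c
/* Counts the cells of a size_x * size_y grid within +/- step cells of the
   input cell (cx, cy) along both axes. Uses whole cell indices, not the
   halved values the shader compares. */
int count_highlighted(int size_x, int size_y, int cx, int cy, int step)
{
    int count = 0;
    for (int x = 0; x < size_x; x++) {
        for (int y = 0; y < size_y; y++) {
            int dx = x < cx ? cx - x : x - cx;
            int dy = y < cy ? cy - y : y - cy;
            if (dx <= step && dy <= step)
                count++;
        }
    }
    return count;
}
```

Before clipping, the window holds (2 * Step + 1)^2 cells, so Steps 0 to 3 give 1, 9, 25 and 49 cells. Around cell (4, 8) none of these windows reach the edge of the grid, so clipping changes nothing there.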

The Source Code:

ShaderToy : https://www.shadertoy.com/view/tscXRf

Unity Shader:

Shader "Unlit/CheckerNeighbours"
{
    Properties
    {
        _Density("Density", Range(2,50)) = 30
        _PosX("X", Float) = 1
        _PosY("Y", Float) = 1
        _Step("Step", Int) = 1
    }
    SubShader
    {
        Tags { "RenderType"="Opaque" }
        LOD 100

        Pass
        {
            CGPROGRAM
            #pragma vertex vert
            #pragma fragment frag
            #include "UnityCG.cginc"

            struct v2f
            {
                float2 uv : TEXCOORD0;
                float4 vertex : SV_POSITION;
            };

            float _Density;
            float _PosX;
            float _PosY;
            int _Step;

            v2f vert(float4 pos : POSITION, float2 uv : TEXCOORD0)
            {
                v2f o;
                o.vertex = UnityObjectToClipPos(pos);
                // Scale the UVs so that each integer step is one grid-cell
                o.uv = uv * _Density;
                return o;
            }

            // True when the point (x, y) is exactly the grid position pt
            bool isOnMap(float x, float y, float2 pt)
            {
                return x == pt.x && y == pt.y;
            }

            // Returns a weight between 0.5 and 1.0 when pt lies within _Step
            // cells of (_PosX, _PosY), and 0 otherwise. pt is a halved cell
            // index (see frag), so neighboring cells are 0.5 apart.
            float isNearest(float2 pt)
            {
                for (float x = -0.5 * _Step; x <= 0.5 * _Step; x += 0.5)
                {
                    for (float y = -0.5 * _Step; y <= 0.5 * _Step; y += 0.5)
                    {
                        if (isOnMap(_PosX + x, _PosY + y, pt))
                        {
                            // With _Step == 0 the formula below would be 0/0
                            if (_Step == 0)
                                return 1.0;
                            // 1.0 at the input cell, down to 0.5 at the corners
                            return _Step / (_Step + abs(x) + abs(y));
                        }
                    }
                }
                return 0;
            }

            fixed4 frag(v2f i) : SV_Target
            {
                // Halved cell index: frac(c.x + c.y) is 0 or 0.5
                float2 c = floor(i.uv) / 2;
                float checker = frac(c.x + c.y) * 2;

                // Highlight neighbors in green, scaled by their weight
                float weight = isNearest(c);
                if (weight > 0)
                {
                    return fixed4(0.0, weight, 0.0, 1.0);
                }
                return fixed4(checker, checker, checker, 1.0);
            }
            ENDCG
        }
    }
}
