Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Golang - concurrent SSH connections to multiple nodes

I have a fleet of servers that I'm trying to establish SSH connections to, and I'm spawning a new goroutine for every SSH connection I have to establish. I then send the result of each connection (along with any errors) down a channel, and then read from the channel. This program sort of works, but it freezes at the end even though I close the channel.

This is what I have so far:

package main

import (
    "fmt"
    "net"
    "sync"

    "github.com/awslabs/aws-sdk-go/aws"
    "github.com/awslabs/aws-sdk-go/service/ec2"
)

// ConnectionResult pairs a host label with the outcome text of a single
// connection attempt.
type ConnectionResult struct {
    host, message string
}

// main lists the running EC2 instances in us-east-1, starts one goroutine per
// "Name" tag to dial TCP port 22, and prints every result received on cnres.
func main() {
    // Unbuffered: each send blocks until the receive loop below takes the value.
    cnres := make(chan ConnectionResult)
    ec2svc := ec2.New(&aws.Config{Region: "us-east-1"})
    wg := sync.WaitGroup{}

    // Restrict the query to instances whose state is "running".
    params := &ec2.DescribeInstancesInput{
        Filters: []*ec2.Filter{
            &ec2.Filter{
                Name: aws.String("instance-state-name"),
                Values: []*string{
                    aws.String("running"),
                },
            },
        },
    }

    resp, err := ec2svc.DescribeInstances(params)
    if err != nil {
        panic(err)
    }

    for _, res := range resp.Reservations {
        for _, inst := range res.Instances {
            for _, tag := range inst.Tags {
                if *tag.Key == "Name" {
                    host := *tag.Value
                    wg.Add(1)
                    // The hostname parameter is never read; the body uses the
                    // captured host variable instead.
                    go func(hostname string, cr chan ConnectionResult) {
                        defer wg.Done()
                        // Only the error is inspected; the returned connection is discarded.
                        _, err := net.Dial("tcp", host+":22")
                        if err != nil {
                            cr <- ConnectionResult{host, "failed"}
                        } else {
                            cr <- ConnectionResult{host, "succeeded"}
                        }
                    }(host, cnres)
                }
            }
        }
    }

    // Receives until cnres is closed; the close call sits after this loop.
    for cr := range cnres {
        fmt.Println("Connection to " + cr.host + " " + cr.message)
    }

    close(cnres)

    // Deferred, so the wait happens only when main returns.
    defer wg.Wait()
}

What am I doing wrong? Is there a better way of doing concurrent SSH connections in Go?

like image 251
George K. Avatar asked Sep 27 '22 19:09

George K.


2 Answers

The code above is stuck in the range cnres for loop. As pointed out in the excellent 'Go by Example', range will only exit on a closed channel.

One way to address that difficulty is to run the range cnres iteration in another goroutine. You could then call wg.Wait(), and then close() the channel, like so:

...
// Print results from a separate goroutine so the calling goroutine stays free
// to wait on wg and then close the channel.
go func() {
        // The range loop ends once cnres has been closed and drained.
        for cr := range cnres {
                fmt.Println("Connection to " + cr.host + " " + cr.message)
        }   
}() 
// Block until every goroutine registered on wg has called Done.
wg.Wait()
// Closing cnres is what lets the range loop above terminate.
// NOTE(review): nothing here waits for the printing goroutine afterwards; if
// the caller returns right after close, the final line may not be printed — confirm.
close(cnres)

On a tangential note (independently of the code being stuck), I think the intention was to use hostname in the Dial() function, and subsequent channel writes, rather than host.

like image 121
Frederik Deweerdt Avatar answered Oct 06 '22 01:10

Frederik Deweerdt


Thanks to Frederik, I was able to get this running successfully:

package main

import (
    "fmt"
    "net"
    "sync"

    "github.com/awslabs/aws-sdk-go/aws"
    "github.com/awslabs/aws-sdk-go/service/ec2"
)

// ConnectionResult pairs a host label with the outcome text of a single
// connection attempt.
type ConnectionResult struct {
    host, message string
}

// main probes TCP port 22 on every running EC2 instance in us-east-1 that
// carries a "Name" tag, and prints one line per instance saying whether the
// connection succeeded or failed.
//
// Each instance is dialed in its own goroutine. Results flow over cnres to a
// single printer goroutine. main waits for the dialers, closes the channel,
// and then waits for the printer, so no result line is lost at exit.
func main() {
    cnres := make(chan ConnectionResult)
    ec2svc := ec2.New(&aws.Config{Region: "us-east-1"})
    var wg sync.WaitGroup

    // Restrict the query to instances whose state is "running".
    params := &ec2.DescribeInstancesInput{
        Filters: []*ec2.Filter{
            &ec2.Filter{
                Name: aws.String("instance-state-name"),
                Values: []*string{
                    aws.String("running"),
                },
            },
        },
    }

    resp, err := ec2svc.DescribeInstances(params)
    if err != nil {
        panic(err)
    }

    // Start the printer before any result can be sent. done is closed once
    // the printer has drained cnres, which lets main wait for the last line.
    done := make(chan struct{})
    go func() {
        defer close(done)
        for cr := range cnres {
            fmt.Println("Connection to " + cr.host + " " + cr.message)
        }
    }()

    for _, res := range resp.Reservations {
        for _, inst := range res.Instances {
            for _, tag := range inst.Tags {
                // SDK fields are pointers; skip tags that cannot be read
                // instead of panicking on a nil dereference.
                if tag.Key == nil || tag.Value == nil || *tag.Key != "Name" {
                    continue
                }
                host := *tag.Value

                // Without a public DNS name there is nothing to dial; an
                // empty name would make ":22" target the local machine.
                if inst.PublicDNSName == nil || *inst.PublicDNSName == "" {
                    cnres <- ConnectionResult{host, "failed"}
                    continue
                }

                wg.Add(1)
                go func(ec2name, cbname string, cr chan<- ConnectionResult) {
                    defer wg.Done()
                    conn, err := net.Dial("tcp", ec2name+":22")
                    if err != nil {
                        cr <- ConnectionResult{cbname, "failed"}
                        return
                    }
                    // This is a reachability probe only, so release the
                    // socket right away; a close error is irrelevant here.
                    _ = conn.Close()
                    cr <- ConnectionResult{cbname, "succeeded"}
                }(*inst.PublicDNSName, host, cnres)
            }
        }
    }

    // All senders have finished once Wait returns, so closing is safe and
    // ends the printer's range loop.
    wg.Wait()
    close(cnres)
    <-done
}
like image 35
George K. Avatar answered Oct 06 '22 01:10

George K.