Commit 0b82c298 authored by Boris Mühmer's avatar Boris Mühmer
Browse files

completed basic expect



Signed-off-by: default avatarBoris Mühmer <boris@muehmer.de>
parent c36520ca
Loading
Loading
Loading
Loading

get.go

0 → 100644
+270 −0
Original line number Diff line number Diff line
package get

import (
	"fmt"
	"io"
	"regexp"
	"sync"

	"repositories.muehmer.net/bsmrgo/get/mb"
)

type Channel uint

const (
	undefined Channel = iota
	In
	Out
	Err
)

const (
	textChannelIn          = "In"
	textChannelOut         = "Out"
	textChannelErr         = "Err"
	textChannelUndefined   = "Undefined"
	formatChannelUndefined = "unknown channel %d"
)

func (c Channel) String() string {
	switch c {
	case In:
		return textChannelIn
	case Out:
		return textChannelOut
	case Err:
		return textChannelErr
	default:
		return fmt.Sprintf(formatChannelUndefined, uint(c))
	}
}

type Expect struct {
	stdin           io.Reader
	stdout          io.Writer
	stderr          io.Writer
	matchStream     chan ChannelPatternValue
	terminateStream chan bool
	finishedStream  chan bool
}

type PatternValue struct {
	pattern *regexp.Regexp
	value   string
}

type ChannelPatternValue struct {
	c Channel
	m PatternValue
}

func New() (*Expect, error) {
	x := &Expect{
		terminateStream: make(chan bool),
		finishedStream:  make(chan bool),
	}

	startedChannel := make(chan bool)
	defer close(startedChannel)
	go x.worker(startedChannel)
	<-startedChannel

	return x, nil
}

func asyncSendBool(c chan<- bool, b bool) {
	go func() {
		c <- b
	}()
}

func (x *Expect) Close() error {
	asyncSendBool(x.terminateStream, true)
	<-x.finishedStream
	return nil
}

func (x *Expect) Stdin() io.Reader {
	return x.stdin
}

func (x *Expect) Stdout() io.Writer {
	return x.stdout
}

func (x *Expect) Stderr() io.Writer {
	return x.stderr
}

func (x *Expect) Match(c Channel, pattern, text string) error {
	match, err := regexp.Compile(pattern)
	if err != nil {
		return err
	}
	x.matchStream <- ChannelPatternValue{
		c: c,
		m: PatternValue{pattern: match, value: text},
	}
	return nil
}

func (x *Expect) worker(started chan<- bool) {
	stdinReader, stdinWriter := io.Pipe()
	stdout := &mb.MutexBuffer{}
	stderr := &mb.MutexBuffer{}

	x.stdin = stdinReader
	x.stdout = stdout
	x.stderr = stderr

	x.matchStream = make(chan ChannelPatternValue)

	matchOutStream := make(chan PatternValue)
	startedOutStream := make(chan bool)
	terminateOutStream := make(chan bool)
	terminatedOutStream := make(chan bool)
	terminatedOut := false

	matchErrStream := make(chan PatternValue)
	startedErrStream := make(chan bool)
	terminateErrStream := make(chan bool)
	terminatedErrStream := make(chan bool)
	terminatedErr := false

	stringStream := make(chan string)
	defer close(stringStream)
	checkStream := make(chan bool)
	defer close(checkStream)

	var wg sync.WaitGroup

	for _, pa := range []struct {
		in               io.Reader
		startedStream    chan<- bool
		matchStream      <-chan PatternValue
		stringStream     chan<- string
		terminateStream  <-chan bool
		terminatedStream chan<- bool
	}{
		{stdout, startedOutStream, matchOutStream, stringStream, terminateOutStream, terminatedOutStream},
		{stderr, startedErrStream, matchErrStream, stringStream, terminateErrStream, terminatedErrStream},
	} {
		wg.Add(1)
		go func(
			in io.Reader,
			startedStream chan<- bool,
			matchStream <-chan PatternValue,
			stringStream chan<- string,
			terminateStream <-chan bool,
			terminatedStream chan<- bool,
		) {
			defer wg.Done()
			process(in, startedStream, matchStream, stringStream, terminateStream, terminatedStream)
		}(pa.in, pa.startedStream, pa.matchStream, pa.stringStream, pa.terminateStream, pa.terminatedStream)
	}

	<-startedOutStream
	<-startedErrStream

	started <- true

	for {
		select {
		case match, valid := <-x.matchStream:
			switch {
			case !valid:
				continue
			case match.c == Out:
				matchOutStream <- match.m
			case match.c == Err:
				matchErrStream <- match.m
			}
		case s, valid := <-stringStream:
			if !valid {
				continue
			}
			fmt.Fprintln(stdinWriter, s)
		case terminate, valid := <-x.terminateStream:
			switch {
			case !valid || !terminate:
				continue
			default:
				asyncSendBool(terminateOutStream, true)
				asyncSendBool(terminateErrStream, true)
			}
		case terminated, valid := <-terminatedOutStream:
			switch {
			case !valid || !terminated:
				continue
			default:
				terminatedOut = true
				asyncSendBool(checkStream, true)
			}
		case terminated, valid := <-terminatedErrStream:
			switch {
			case !valid || !terminated:
				continue
			default:
				terminatedErr = true
				asyncSendBool(checkStream, true)
			}
		case check, valid := <-checkStream:
			switch {
			case !valid || !check || !terminatedOut || !terminatedErr:
				continue
			default:
				wg.Wait()
				x.finishedStream <- true
			}
		}
	}
}

func process(
	in io.Reader,
	startedStream chan<- bool,
	matchStream <-chan PatternValue,
	stringStream chan<- string,
	terminateStream <-chan bool,
	terminatedStream chan<- bool,
) {
	defer asyncSendBool(terminatedStream, true)
	matches := []PatternValue{}

	startedStream <- true

	for {
		select {
		case match, valid := <-matchStream:
			if !valid {
				continue
			}
			matches = append(matches, match)
		case terminate, valid := <-terminateStream:
			switch {
			case !valid || !terminate:
				continue
			default:
				return
			}
		default:
			bytes, err := io.ReadAll(in)
			if err != nil {
				continue
			}
			if len(bytes) == 0 {
				continue
			}
			for _, m := range matches {
				s := string(bytes)
				if !m.pattern.MatchString(s) {
					continue
				}
				go func(s string) {
					stringStream <- s
				}(m.value)
				break
			}
		}
	}
}

get_test.go

0 → 100644
+65 −0
Original line number Diff line number Diff line
package get

import (
	"fmt"
	"testing"

	"repositories.muehmer.net/bsmrgo/get/mock/passwd"
)

func TestChannelStrings(t *testing.T) {
	uc := 99
	for _, c := range []struct {
		c Channel
		s string
	}{
		{c: In, s: textChannelIn},
		{c: Out, s: textChannelOut},
		{c: Err, s: textChannelErr},
		{c: Channel(uc), s: fmt.Sprintf(formatChannelUndefined, uc)},
	} {
		if s := c.c.String(); s != c.s {
			t.Errorf("String() is %q, expected %q", s, c.s)
		}
	}
}

const (
	lp1User = "user1"
	lp1Pass = "pass1"
	lp2User = "user2"
	lp2Pass = "pass2"
	lpNPass = "#S3cr3T!"
	lpXPass = "wr0ng"
)

func TestPasswdChange(t *testing.T) {
	x, err := New()
	if err != nil {
		t.Fatalf("get.New() failed with: %s", err)
	}
	defer x.Close()

	for _, cpv := range []struct {
		c Channel
		p string
		v string
	}{
		{c: Err, p: `^[cC]urrent password`, v: lp2Pass},
		{c: Err, p: `^[nN]ew password`, v: lpNPass},
		{c: Err, p: `^[rR]etype new password`, v: lpNPass},
	} {
		if err := x.Match(cpv.c, cpv.p, cpv.v); err != nil {
			t.Errorf("x.Match(%s,%q,%q) failed with: %s", cpv.c, cpv.p, cpv.v, err)
		}
	}

	p, err := passwd.New(x.Stdin(), x.Stdout(), x.Stderr(), lp2User)
	if err != nil {
		t.Fatalf("passwd.New() failed with: %s", err)
	}

	if err := p.Run(); err != nil {
		t.Fatalf("p.Run() failed with: %s", err)
	}
}

go.mod

0 → 100644
+3 −0
Original line number Diff line number Diff line
module repositories.muehmer.net/bsmrgo/get

go 1.17

mb/mb.go

0 → 100644
+23 −0
Original line number Diff line number Diff line
package mb

import (
	"bytes"
	"sync"
)

type MutexBuffer struct {
	buffer bytes.Buffer
	mutex  sync.Mutex
}

func (mb *MutexBuffer) Read(p []byte) (n int, err error) {
	mb.mutex.Lock()
	defer mb.mutex.Unlock()
	return mb.buffer.Read(p)
}

func (mb *MutexBuffer) Write(p []byte) (n int, err error) {
	mb.mutex.Lock()
	defer mb.mutex.Unlock()
	return mb.buffer.Write(p)
}

mb/mb_test.go

0 → 100644
+28 −0
Original line number Diff line number Diff line
package mb

import (
	"fmt"
	"io"
	"testing"
)

func TestSkip(t *testing.T) {
	var mb MutexBuffer

	ri := io.Reader(&mb)
	wi := io.Writer(&mb)

	msg := "Hello, World!"
	if _, err := fmt.Fprintf(wi, "%q", msg); err != nil {
		t.Errorf("fmt.Fprintln() failed with: %s", err)
	}

	var res string
	if _, err := fmt.Fscanf(ri, "%q", &res); err != nil {
		t.Errorf("fmt.Fscanf() failed with: %s", err)
	}

	if res != msg {
		t.Errorf("res is %q, expected %q", res, msg)
	}
}
Loading