hci: multiple connections

Signed-off-by: deadprogram <ron@hybridgroup.com>
This commit is contained in:
deadprogram 2024-01-26 21:00:49 +01:00 committed by BCG
parent 553633e56a
commit 3e90718eb8
2 changed files with 163 additions and 83 deletions

View file

@ -67,12 +67,15 @@ const (
) )
var ( var (
ErrATTTimeout = errors.New("bluetooth: ATT timeout") ErrATTTimeout = errors.New("bluetooth: ATT timeout")
ErrATTUnknownEvent = errors.New("bluetooth: ATT unknown event") ErrATTUnknownEvent = errors.New("bluetooth: ATT unknown event")
ErrATTUnknown = errors.New("bluetooth: ATT unknown error") ErrATTUnknown = errors.New("bluetooth: ATT unknown error")
ErrATTOp = errors.New("bluetooth: ATT OP error") ErrATTOp = errors.New("bluetooth: ATT OP error")
ErrATTUnknownConnection = errors.New("bluetooth: ATT unknown connection")
) )
const defaultTimeoutSeconds = 10
type rawService struct { type rawService struct {
startHandle uint16 startHandle uint16
endHandle uint16 endHandle uint16
@ -249,9 +252,7 @@ func (a *rawAttribute) length() int {
} }
} }
type att struct { type connectData struct {
hci *hci
busy sync.Mutex
responded bool responded bool
errored bool errored bool
lastErrorOpcode uint8 lastErrorOpcode uint8
@ -263,26 +264,34 @@ type att struct {
characteristics []rawCharacteristic characteristics []rawCharacteristic
descriptors []rawDescriptor descriptors []rawDescriptor
value []byte value []byte
notifications chan rawNotification }
connections []uint16 type att struct {
lastHandle uint16 hci *hci
attributes []rawAttribute busy sync.Mutex
localServices []rawService mtu uint16
maxMTU uint16
notifications chan rawNotification
connections []uint16
connectionsData map[uint16]*connectData
lastHandle uint16
localServices []rawService
localCharacteristics []rawCharacteristic
attributes []rawAttribute
} }
func newATT(hci *hci) *att { func newATT(hci *hci) *att {
return &att{ return &att{
hci: hci, hci: hci,
services: []rawService{}, localCharacteristics: []rawCharacteristic{},
characteristics: []rawCharacteristic{}, notifications: make(chan rawNotification, 32),
value: []byte{}, connections: []uint16{},
notifications: make(chan rawNotification, 32), connectionsData: make(map[uint16]*connectData),
connections: []uint16{}, lastHandle: 0x0001,
lastHandle: 0x0001, attributes: []rawAttribute{},
attributes: []rawAttribute{}, localServices: []rawService{},
localServices: []rawService{}, maxMTU: 248,
maxMTU: 248,
} }
} }
@ -304,7 +313,7 @@ func (a *att) readByGroupReq(connectionHandle, startHandle, endHandle uint16, uu
return err return err
} }
return a.waitUntilResponse() return a.waitUntilResponse(connectionHandle)
} }
func (a *att) readByTypeReq(connectionHandle, startHandle, endHandle uint16, typ uint16) error { func (a *att) readByTypeReq(connectionHandle, startHandle, endHandle uint16, typ uint16) error {
@ -325,7 +334,7 @@ func (a *att) readByTypeReq(connectionHandle, startHandle, endHandle uint16, typ
return err return err
} }
return a.waitUntilResponse() return a.waitUntilResponse(connectionHandle)
} }
func (a *att) findInfoReq(connectionHandle, startHandle, endHandle uint16) error { func (a *att) findInfoReq(connectionHandle, startHandle, endHandle uint16) error {
@ -345,7 +354,7 @@ func (a *att) findInfoReq(connectionHandle, startHandle, endHandle uint16) error
return err return err
} }
return a.waitUntilResponse() return a.waitUntilResponse(connectionHandle)
} }
func (a *att) readReq(connectionHandle, valueHandle uint16) error { func (a *att) readReq(connectionHandle, valueHandle uint16) error {
@ -364,7 +373,7 @@ func (a *att) readReq(connectionHandle, valueHandle uint16) error {
return err return err
} }
return a.waitUntilResponse() return a.waitUntilResponse(connectionHandle)
} }
func (a *att) writeCmd(connectionHandle, valueHandle uint16, data []byte) error { func (a *att) writeCmd(connectionHandle, valueHandle uint16, data []byte) error {
@ -402,7 +411,7 @@ func (a *att) writeReq(connectionHandle, valueHandle uint16, data []byte) error
return err return err
} }
return a.waitUntilResponse() return a.waitUntilResponse(connectionHandle)
} }
func (a *att) mtuReq(connectionHandle uint16) error { func (a *att) mtuReq(connectionHandle uint16) error {
@ -410,18 +419,23 @@ func (a *att) mtuReq(connectionHandle uint16) error {
println("att.mtuReq:", connectionHandle) println("att.mtuReq:", connectionHandle)
} }
cd, err := a.findConnectionData(connectionHandle)
if err != nil {
return err
}
a.busy.Lock() a.busy.Lock()
defer a.busy.Unlock() defer a.busy.Unlock()
var b [3]byte var b [3]byte
b[0] = attOpMTUReq b[0] = attOpMTUReq
binary.LittleEndian.PutUint16(b[1:], a.mtu) binary.LittleEndian.PutUint16(b[1:], cd.mtu)
if err := a.sendReq(connectionHandle, b[:]); err != nil { if err := a.sendReq(connectionHandle, b[:]); err != nil {
return err return err
} }
return a.waitUntilResponse() return a.waitUntilResponse(connectionHandle)
} }
func (a *att) setMaxMTU(mtu uint16) error { func (a *att) setMaxMTU(mtu uint16) error {
@ -431,7 +445,9 @@ func (a *att) setMaxMTU(mtu uint16) error {
} }
func (a *att) sendReq(handle uint16, data []byte) error { func (a *att) sendReq(handle uint16, data []byte) error {
a.clearResponse() if err := a.clearResponse(handle); err != nil {
return err
}
if debug { if debug {
println("att.sendReq:", handle, "data:", hex.EncodeToString(data)) println("att.sendReq:", handle, "data:", hex.EncodeToString(data))
@ -470,7 +486,9 @@ func (a *att) sendNotification(handle uint16, data []byte) error {
} }
func (a *att) sendError(handle uint16, opcode uint8, hdl uint16, code uint8) error { func (a *att) sendError(handle uint16, opcode uint8, hdl uint16, code uint8) error {
a.clearResponse() if err := a.clearResponse(handle); err != nil {
return err
}
if debug { if debug {
println("att.sendError:", handle, "data:", opcode, hdl, code) println("att.sendError:", handle, "data:", opcode, hdl, code)
@ -494,15 +512,20 @@ func (a *att) handleData(handle uint16, buf []byte) error {
println("att.handleData:", handle, "data:", hex.EncodeToString(buf)) println("att.handleData:", handle, "data:", hex.EncodeToString(buf))
} }
cd, err := a.findConnectionData(handle)
if err != nil {
return err
}
switch buf[0] { switch buf[0] {
case attOpError: case attOpError:
a.errored = true cd.errored = true
a.lastErrorOpcode = buf[1] cd.lastErrorOpcode = buf[1]
a.lastErrorHandle = binary.LittleEndian.Uint16(buf[2:]) cd.lastErrorHandle = binary.LittleEndian.Uint16(buf[2:])
a.lastErrorCode = buf[4] cd.lastErrorCode = buf[4]
if debug { if debug {
println("att.handleData: attOpERROR", a.lastErrorOpcode, a.lastErrorCode) println("att.handleData: attOpERROR", handle, cd.lastErrorOpcode, cd.lastErrorCode)
} }
return ErrATTOp return ErrATTOp
@ -517,7 +540,7 @@ func (a *att) handleData(handle uint16, buf []byte) error {
} }
// save mtu for connection // save mtu for connection
a.mtu = mtu cd.mtu = mtu
var b [3]byte var b [3]byte
b[0] = attOpMTUResponse b[0] = attOpMTUResponse
@ -531,8 +554,8 @@ func (a *att) handleData(handle uint16, buf []byte) error {
if debug { if debug {
println("att.handleData: attOpMTUResponse") println("att.handleData: attOpMTUResponse")
} }
a.responded = true cd.responded = true
a.mtu = binary.LittleEndian.Uint16(buf[1:]) cd.mtu = binary.LittleEndian.Uint16(buf[1:])
case attOpFindInfoReq: case attOpFindInfoReq:
if debug { if debug {
@ -548,7 +571,7 @@ func (a *att) handleData(handle uint16, buf []byte) error {
if debug { if debug {
println("att.handleData: attOpFindInfoResponse") println("att.handleData: attOpFindInfoResponse")
} }
a.responded = true cd.responded = true
lengthPerDescriptor := int(buf[1]) lengthPerDescriptor := int(buf[1])
@ -560,7 +583,7 @@ func (a *att) handleData(handle uint16, buf []byte) error {
println("att.handleData: descriptor", d.handle, hex.EncodeToString(d.data)) println("att.handleData: descriptor", d.handle, hex.EncodeToString(d.data))
} }
a.descriptors = append(a.descriptors, d) cd.descriptors = append(cd.descriptors, d)
} }
case attOpFindByTypeReq: case attOpFindByTypeReq:
@ -583,7 +606,7 @@ func (a *att) handleData(handle uint16, buf []byte) error {
if debug { if debug {
println("att.handleData: attOpReadByTypeResponse") println("att.handleData: attOpReadByTypeResponse")
} }
a.responded = true cd.responded = true
lengthPerCharacteristic := int(buf[1]) lengthPerCharacteristic := int(buf[1])
@ -595,7 +618,7 @@ func (a *att) handleData(handle uint16, buf []byte) error {
println("att.handleData: characteristic", c.startHandle, c.properties, c.valueHandle, c.uuid.String()) println("att.handleData: characteristic", c.startHandle, c.properties, c.valueHandle, c.uuid.String())
} }
a.characteristics = append(a.characteristics, c) cd.characteristics = append(cd.characteristics, c)
} }
return nil return nil
@ -615,7 +638,7 @@ func (a *att) handleData(handle uint16, buf []byte) error {
if debug { if debug {
println("att.handleData: attOpReadByGroupResponse") println("att.handleData: attOpReadByGroupResponse")
} }
a.responded = true cd.responded = true
lengthPerService := int(buf[1]) lengthPerService := int(buf[1])
@ -627,7 +650,7 @@ func (a *att) handleData(handle uint16, buf []byte) error {
println("att.handleData: service", service.startHandle, service.endHandle, service.uuid.String()) println("att.handleData: service", service.startHandle, service.endHandle, service.uuid.String())
} }
a.services = append(a.services, service) cd.services = append(cd.services, service)
} }
return nil return nil
@ -649,8 +672,8 @@ func (a *att) handleData(handle uint16, buf []byte) error {
if debug { if debug {
println("att.handleData: attOpReadResponse") println("att.handleData: attOpReadResponse")
} }
a.responded = true cd.responded = true
a.value = append(a.value, buf[1:]...) cd.value = append(cd.value, buf[1:]...)
case attOpWriteReq: case attOpWriteReq:
if debug { if debug {
@ -669,7 +692,7 @@ func (a *att) handleData(handle uint16, buf []byte) error {
if debug { if debug {
println("att.handleData: attOpWriteResponse") println("att.handleData: attOpWriteResponse")
} }
a.responded = true cd.responded = true
case attOpPrepWriteReq: case attOpPrepWriteReq:
if debug { if debug {
@ -798,7 +821,7 @@ func (a *att) handleReadByTypeReq(handle, start, end uint16, uuid shortUUID) err
pos = 2 pos = 2
response[1] = 0 response[1] = 0
for _, c := range a.characteristics { for _, c := range a.localCharacteristics {
if debug { if debug {
println("handleReadByTypeReq: looking at characteristic", c.startHandle, c.uuid.String()) println("handleReadByTypeReq: looking at characteristic", c.startHandle, c.uuid.String())
} }
@ -1021,16 +1044,28 @@ func (a *att) handleWriteReq(handle, attrHandle uint16, data []byte) error {
return a.sendError(handle, attOpWriteReq, attrHandle, attErrorWriteNotPermitted) return a.sendError(handle, attOpWriteReq, attrHandle, attErrorWriteNotPermitted)
} }
func (a *att) clearResponse() { func (a *att) clearResponse(handle uint16) error {
a.responded = false cd, err := a.findConnectionData(handle)
a.errored = false if err != nil {
a.lastErrorOpcode = 0 return err
a.lastErrorHandle = 0 }
a.lastErrorCode = 0
a.value = []byte{} cd.responded = false
cd.errored = false
cd.lastErrorOpcode = 0
cd.lastErrorHandle = 0
cd.lastErrorCode = 0
cd.value = []byte{}
return nil
} }
func (a *att) waitUntilResponse() error { func (a *att) waitUntilResponse(handle uint16) error {
cd, err := a.findConnectionData(handle)
if err != nil {
return err
}
start := time.Now().UnixNano() start := time.Now().UnixNano()
for { for {
if err := a.hci.poll(); err != nil { if err := a.hci.poll(); err != nil {
@ -1038,15 +1073,14 @@ func (a *att) waitUntilResponse() error {
} }
switch { switch {
case a.responded: case cd.responded:
return nil return nil
case (time.Now().UnixNano()-start)/int64(time.Second) > defaultTimeoutSeconds:
return ErrATTTimeout
default: default:
// check for timeout // check for timeout
if (time.Now().UnixNano()-start)/int64(time.Second) > 3 {
break
}
time.Sleep(5 * time.Millisecond) time.Sleep(5 * time.Millisecond)
} }
} }
@ -1066,15 +1100,28 @@ func (a *att) poll() error {
} }
func (a *att) addConnection(handle uint16) error { func (a *att) addConnection(handle uint16) error {
if debug {
println("att.addConnection:", handle)
}
a.connections = append(a.connections, handle) a.connections = append(a.connections, handle)
a.connectionsData[handle] = &connectData{
services: []rawService{},
characteristics: []rawCharacteristic{},
value: []byte{},
}
return nil return nil
} }
func (a *att) removeConnection(handle uint16) error { func (a *att) removeConnection(handle uint16) error {
if debug {
println("att.removeConnection:", handle)
}
for i := range a.connections { for i := range a.connections {
if a.connections[i] == handle { if a.connections[i] == handle {
a.connections = append(a.connections[:i], a.connections[i+1:]...) a.connections = append(a.connections[:i], a.connections[i+1:]...)
delete(a.connectionsData, handle)
break break
} }
} }
@ -1107,7 +1154,7 @@ func (a *att) addLocalService(start, end uint16, uuid UUID) {
} }
func (a *att) addLocalCharacteristic(startHandle uint16, properties CharacteristicPermissions, valueHandle uint16, uuid UUID, chr *Characteristic) { func (a *att) addLocalCharacteristic(startHandle uint16, properties CharacteristicPermissions, valueHandle uint16, uuid UUID, chr *Characteristic) {
a.characteristics = append(a.characteristics, a.localCharacteristics = append(a.localCharacteristics,
rawCharacteristic{ rawCharacteristic{
startHandle: startHandle, startHandle: startHandle,
properties: uint8(properties), properties: uint8(properties),
@ -1128,11 +1175,29 @@ func (a *att) findAttribute(hdl uint16) *rawAttribute {
} }
func (a *att) findCharacteristic(hdl uint16) *rawCharacteristic { func (a *att) findCharacteristic(hdl uint16) *rawCharacteristic {
for i := range a.characteristics { for i := range a.localCharacteristics {
if a.characteristics[i].startHandle == hdl { if a.localCharacteristics[i].startHandle == hdl {
return &a.characteristics[i] return &a.localCharacteristics[i]
} }
} }
return nil return nil
} }
func (a *att) findConnectionData(handle uint16) (*connectData, error) {
cd, ok := a.connectionsData[handle]
if !ok {
return nil, ErrATTUnknownConnection
}
return cd, nil
}
func (a *att) lastError(handle uint16) (uint8, uint16, uint8) {
cd, err := a.findConnectionData(handle)
if err != nil {
return 0, 0, 0
}
return cd.lastErrorOpcode, cd.lastErrorHandle, cd.lastErrorCode
}

View file

@ -59,6 +59,11 @@ func (d Device) DiscoverServices(uuids []UUID) ([]DeviceService, error) {
services := make([]DeviceService, 0, maxDefaultServicesToDiscover) services := make([]DeviceService, 0, maxDefaultServicesToDiscover)
foundServices := make(map[UUID]DeviceService) foundServices := make(map[UUID]DeviceService)
cd, err := d.adapter.att.findConnectionData(d.handle)
if err != nil {
return nil, err
}
startHandle := uint16(0x0001) startHandle := uint16(0x0001)
endHandle := uint16(0xffff) endHandle := uint16(0xffff)
for endHandle == uint16(0xffff) { for endHandle == uint16(0xffff) {
@ -68,14 +73,14 @@ func (d Device) DiscoverServices(uuids []UUID) ([]DeviceService, error) {
} }
if debug { if debug {
println("found d.adapter.att.services", len(d.adapter.att.services)) println("found services", len(cd.services))
} }
if len(d.adapter.att.services) == 0 { if len(cd.services) == 0 {
break break
} }
for _, rawService := range d.adapter.att.services { for _, rawService := range cd.services {
if len(uuids) == 0 || rawService.uuid.isIn(uuids) { if len(uuids) == 0 || rawService.uuid.isIn(uuids) {
foundServices[rawService.uuid] = foundServices[rawService.uuid] =
DeviceService{ DeviceService{
@ -93,7 +98,7 @@ func (d Device) DiscoverServices(uuids []UUID) ([]DeviceService, error) {
} }
// reset raw services // reset raw services
d.adapter.att.services = []rawService{} cd.services = []rawService{}
// did we find them all? // did we find them all?
if len(foundServices) == len(uuids) { if len(foundServices) == len(uuids) {
@ -155,30 +160,35 @@ func (s DeviceService) DiscoverCharacteristics(uuids []UUID) ([]DeviceCharacteri
characteristics := make([]DeviceCharacteristic, 0, maxDefaultCharacteristicsToDiscover) characteristics := make([]DeviceCharacteristic, 0, maxDefaultCharacteristicsToDiscover)
foundCharacteristics := make(map[UUID]DeviceCharacteristic) foundCharacteristics := make(map[UUID]DeviceCharacteristic)
cd, err := s.device.adapter.att.findConnectionData(s.device.handle)
if err != nil {
return nil, err
}
startHandle := s.startHandle startHandle := s.startHandle
endHandle := s.endHandle endHandle := s.endHandle
for startHandle < endHandle { for startHandle < endHandle {
err := s.device.adapter.att.readByTypeReq(s.device.handle, startHandle, endHandle, gattCharacteristicUUID) err := s.device.adapter.att.readByTypeReq(s.device.handle, startHandle, endHandle, gattCharacteristicUUID)
switch { switch {
case err == ErrATTOp && case err == ErrATTOp:
s.device.adapter.att.lastErrorOpcode == attOpReadByTypeReq && opcode, _, errcode := s.device.adapter.att.lastError(s.device.handle)
s.device.adapter.att.lastErrorCode == attErrorAttrNotFound: if opcode == attOpReadByTypeReq && errcode == attErrorAttrNotFound {
// no characteristics found
// no characteristics found break
break }
case err != nil: case err != nil:
return nil, err return nil, err
} }
if debug { if debug {
println("found s.device.adapter.att.characteristics", len(s.device.adapter.att.characteristics)) println("found characteristics", len(cd.characteristics))
} }
if len(s.device.adapter.att.characteristics) == 0 { if len(cd.characteristics) == 0 {
break break
} }
for _, rawCharacteristic := range s.device.adapter.att.characteristics { for _, rawCharacteristic := range cd.characteristics {
if len(uuids) == 0 || rawCharacteristic.uuid.isIn(uuids) { if len(uuids) == 0 || rawCharacteristic.uuid.isIn(uuids) {
dc := DeviceCharacteristic{ dc := DeviceCharacteristic{
service: &s, service: &s,
@ -195,7 +205,7 @@ func (s DeviceService) DiscoverCharacteristics(uuids []UUID) ([]DeviceCharacteri
} }
// reset raw characteristics // reset raw characteristics
s.device.adapter.att.characteristics = []rawCharacteristic{} cd.characteristics = []rawCharacteristic{}
// did we find them all? // did we find them all?
if len(foundCharacteristics) == len(uuids) { if len(foundCharacteristics) == len(uuids) {
@ -305,11 +315,16 @@ func (c DeviceCharacteristic) Read(data []byte) (int, error) {
return 0, err return 0, err
} }
if len(c.service.device.adapter.att.value) == 0 { cd, err := c.service.device.adapter.att.findConnectionData(c.service.device.handle)
if err != nil {
return 0, err
}
if len(cd.value) == 0 {
return 0, errReadFailed return 0, errReadFailed
} }
copy(data, c.service.device.adapter.att.value) copy(data, cd.value)
return len(c.service.device.adapter.att.value), nil return len(cd.value), nil
} }