From 3e90718eb87151597e5a83ea1693ed659d4926e0 Mon Sep 17 00:00:00 2001 From: deadprogram Date: Fri, 26 Jan 2024 21:00:49 +0100 Subject: [PATCH] hci: multiple connections Signed-off-by: deadprogram --- att_hci.go | 197 ++++++++++++++++++++++++++++++++++----------------- gattc_hci.go | 49 ++++++++----- 2 files changed, 163 insertions(+), 83 deletions(-) diff --git a/att_hci.go b/att_hci.go index bed3d25..c69d367 100644 --- a/att_hci.go +++ b/att_hci.go @@ -67,12 +67,15 @@ const ( ) var ( - ErrATTTimeout = errors.New("bluetooth: ATT timeout") - ErrATTUnknownEvent = errors.New("bluetooth: ATT unknown event") - ErrATTUnknown = errors.New("bluetooth: ATT unknown error") - ErrATTOp = errors.New("bluetooth: ATT OP error") + ErrATTTimeout = errors.New("bluetooth: ATT timeout") + ErrATTUnknownEvent = errors.New("bluetooth: ATT unknown event") + ErrATTUnknown = errors.New("bluetooth: ATT unknown error") + ErrATTOp = errors.New("bluetooth: ATT OP error") + ErrATTUnknownConnection = errors.New("bluetooth: ATT unknown connection") ) +const defaultTimeoutSeconds = 10 + type rawService struct { startHandle uint16 endHandle uint16 @@ -249,9 +252,7 @@ func (a *rawAttribute) length() int { } } -type att struct { - hci *hci - busy sync.Mutex +type connectData struct { responded bool errored bool lastErrorOpcode uint8 @@ -263,26 +264,34 @@ type att struct { characteristics []rawCharacteristic descriptors []rawDescriptor value []byte - notifications chan rawNotification +} - connections []uint16 - lastHandle uint16 - attributes []rawAttribute - localServices []rawService +type att struct { + hci *hci + busy sync.Mutex + mtu uint16 + maxMTU uint16 + notifications chan rawNotification + + connections []uint16 + connectionsData map[uint16]*connectData + lastHandle uint16 + localServices []rawService + localCharacteristics []rawCharacteristic + attributes []rawAttribute } func newATT(hci *hci) *att { return &att{ - hci: hci, - services: []rawService{}, - characteristics: []rawCharacteristic{}, - value: []byte{}, - notifications: make(chan rawNotification, 32), - connections: []uint16{}, - lastHandle: 0x0001, - attributes: []rawAttribute{}, - localServices: []rawService{}, - maxMTU: 248, + hci: hci, + localCharacteristics: []rawCharacteristic{}, + notifications: make(chan rawNotification, 32), + connections: []uint16{}, + connectionsData: make(map[uint16]*connectData), + lastHandle: 0x0001, + attributes: []rawAttribute{}, + localServices: []rawService{}, + maxMTU: 248, } } @@ -304,7 +313,7 @@ func (a *att) readByGroupReq(connectionHandle, startHandle, endHandle uint16, uu return err } - return a.waitUntilResponse() + return a.waitUntilResponse(connectionHandle) } func (a *att) readByTypeReq(connectionHandle, startHandle, endHandle uint16, typ uint16) error { @@ -325,7 +334,7 @@ func (a *att) readByTypeReq(connectionHandle, startHandle, endHandle uint16, typ return err } - return a.waitUntilResponse() + return a.waitUntilResponse(connectionHandle) } func (a *att) findInfoReq(connectionHandle, startHandle, endHandle uint16) error { @@ -345,7 +354,7 @@ func (a *att) findInfoReq(connectionHandle, startHandle, endHandle uint16) error return err } - return a.waitUntilResponse() + return a.waitUntilResponse(connectionHandle) } func (a *att) readReq(connectionHandle, valueHandle uint16) error { @@ -364,7 +373,7 @@ func (a *att) readReq(connectionHandle, valueHandle uint16) error { return err } - return a.waitUntilResponse() + return a.waitUntilResponse(connectionHandle) } func (a *att) writeCmd(connectionHandle, valueHandle uint16, data []byte) error { @@ -402,7 +411,7 @@ func (a *att) writeReq(connectionHandle, valueHandle uint16, data []byte) error return err } - return a.waitUntilResponse() + return a.waitUntilResponse(connectionHandle) } func (a *att) mtuReq(connectionHandle uint16) error { @@ -410,18 +419,23 @@ func (a *att) mtuReq(connectionHandle uint16) error { println("att.mtuReq:", connectionHandle) } + cd, err := a.findConnectionData(connectionHandle) + if err != nil { + return err + } + a.busy.Lock() defer a.busy.Unlock() var b [3]byte b[0] = attOpMTUReq - binary.LittleEndian.PutUint16(b[1:], a.mtu) + binary.LittleEndian.PutUint16(b[1:], cd.mtu) if err := a.sendReq(connectionHandle, b[:]); err != nil { return err } - return a.waitUntilResponse() + return a.waitUntilResponse(connectionHandle) } func (a *att) setMaxMTU(mtu uint16) error { @@ -431,7 +445,9 @@ func (a *att) setMaxMTU(mtu uint16) error { } func (a *att) sendReq(handle uint16, data []byte) error { - a.clearResponse() + if err := a.clearResponse(handle); err != nil { + return err + } if debug { println("att.sendReq:", handle, "data:", hex.EncodeToString(data)) @@ -470,7 +486,9 @@ func (a *att) sendNotification(handle uint16, data []byte) error { } func (a *att) sendError(handle uint16, opcode uint8, hdl uint16, code uint8) error { - a.clearResponse() + if err := a.clearResponse(handle); err != nil { + return err + } if debug { println("att.sendError:", handle, "data:", opcode, hdl, code) @@ -494,15 +512,20 @@ func (a *att) handleData(handle uint16, buf []byte) error { println("att.handleData:", handle, "data:", hex.EncodeToString(buf)) } + cd, err := a.findConnectionData(handle) + if err != nil { + return err + } + switch buf[0] { case attOpError: - a.errored = true - a.lastErrorOpcode = buf[1] - a.lastErrorHandle = binary.LittleEndian.Uint16(buf[2:]) - a.lastErrorCode = buf[4] + cd.errored = true + cd.lastErrorOpcode = buf[1] + cd.lastErrorHandle = binary.LittleEndian.Uint16(buf[2:]) + cd.lastErrorCode = buf[4] if debug { - println("att.handleData: attOpERROR", a.lastErrorOpcode, a.lastErrorCode) + println("att.handleData: attOpERROR", handle, cd.lastErrorOpcode, cd.lastErrorCode) } return ErrATTOp @@ -517,7 +540,7 @@ func (a *att) handleData(handle uint16, buf []byte) error { } // save mtu for connection - a.mtu = mtu + cd.mtu = mtu var b [3]byte b[0] = attOpMTUResponse @@ -531,8 +554,8 @@ func (a *att) handleData(handle uint16, buf []byte) error { if debug { println("att.handleData: attOpMTUResponse") } - a.responded = true - a.mtu = binary.LittleEndian.Uint16(buf[1:]) + cd.responded = true + cd.mtu = binary.LittleEndian.Uint16(buf[1:]) case attOpFindInfoReq: if debug { @@ -548,7 +571,7 @@ func (a *att) handleData(handle uint16, buf []byte) error { if debug { println("att.handleData: attOpFindInfoResponse") } - a.responded = true + cd.responded = true lengthPerDescriptor := int(buf[1]) @@ -560,7 +583,7 @@ func (a *att) handleData(handle uint16, buf []byte) error { println("att.handleData: descriptor", d.handle, hex.EncodeToString(d.data)) } - a.descriptors = append(a.descriptors, d) + cd.descriptors = append(cd.descriptors, d) } case attOpFindByTypeReq: @@ -583,7 +606,7 @@ func (a *att) handleData(handle uint16, buf []byte) error { if debug { println("att.handleData: attOpReadByTypeResponse") } - a.responded = true + cd.responded = true lengthPerCharacteristic := int(buf[1]) @@ -595,7 +618,7 @@ func (a *att) handleData(handle uint16, buf []byte) error { println("att.handleData: characteristic", c.startHandle, c.properties, c.valueHandle, c.uuid.String()) } - a.characteristics = append(a.characteristics, c) + cd.characteristics = append(cd.characteristics, c) } return nil @@ -615,7 +638,7 @@ func (a *att) handleData(handle uint16, buf []byte) error { if debug { println("att.handleData: attOpReadByGroupResponse") } - a.responded = true + cd.responded = true lengthPerService := int(buf[1]) @@ -627,7 +650,7 @@ func (a *att) handleData(handle uint16, buf []byte) error { println("att.handleData: service", service.startHandle, service.endHandle, service.uuid.String()) } - a.services = append(a.services, service) + cd.services = append(cd.services, service) } return nil @@ -649,8 +672,8 @@ func (a *att) handleData(handle uint16, buf []byte) error { if debug { println("att.handleData: attOpReadResponse") } - a.responded = true - a.value = append(a.value, buf[1:]...) + cd.responded = true + cd.value = append(cd.value, buf[1:]...) case attOpWriteReq: if debug { @@ -669,7 +692,7 @@ func (a *att) handleData(handle uint16, buf []byte) error { if debug { println("att.handleData: attOpWriteResponse") } - a.responded = true + cd.responded = true case attOpPrepWriteReq: if debug { @@ -798,7 +821,7 @@ func (a *att) handleReadByTypeReq(handle, start, end uint16, uuid shortUUID) err pos = 2 response[1] = 0 - for _, c := range a.characteristics { + for _, c := range a.localCharacteristics { if debug { println("handleReadByTypeReq: looking at characteristic", c.startHandle, c.uuid.String()) } @@ -1021,16 +1044,28 @@ func (a *att) handleWriteReq(handle, attrHandle uint16, data []byte) error { return a.sendError(handle, attOpWriteReq, attrHandle, attErrorWriteNotPermitted) } -func (a *att) clearResponse() { - a.responded = false - a.errored = false - a.lastErrorOpcode = 0 - a.lastErrorHandle = 0 - a.lastErrorCode = 0 - a.value = []byte{} +func (a *att) clearResponse(handle uint16) error { + cd, err := a.findConnectionData(handle) + if err != nil { + return err + } + + cd.responded = false + cd.errored = false + cd.lastErrorOpcode = 0 + cd.lastErrorHandle = 0 + cd.lastErrorCode = 0 + cd.value = []byte{} + + return nil } -func (a *att) waitUntilResponse() error { +func (a *att) waitUntilResponse(handle uint16) error { + cd, err := a.findConnectionData(handle) + if err != nil { + return err + } + start := time.Now().UnixNano() for { if err := a.hci.poll(); err != nil { @@ -1038,15 +1073,14 @@ func (a *att) waitUntilResponse() error { } switch { - case a.responded: + case cd.responded: return nil + case (time.Now().UnixNano()-start)/int64(time.Second) > defaultTimeoutSeconds: + return ErrATTTimeout + default: // check for timeout - if (time.Now().UnixNano()-start)/int64(time.Second) > 3 { - break - } - time.Sleep(5 * time.Millisecond) } } @@ -1066,15 +1100,28 @@ func (a *att) poll() error { } func (a *att) addConnection(handle uint16) error { + if debug { + println("att.addConnection:", handle) + } a.connections = append(a.connections, handle) + a.connectionsData[handle] = &connectData{ + services: []rawService{}, + characteristics: []rawCharacteristic{}, + value: []byte{}, + } return nil } func (a *att) removeConnection(handle uint16) error { + if debug { + println("att.removeConnection:", handle) + } + for i := range a.connections { if a.connections[i] == handle { a.connections = append(a.connections[:i], a.connections[i+1:]...) + delete(a.connectionsData, handle) break } } @@ -1107,7 +1154,7 @@ func (a *att) addLocalService(start, end uint16, uuid UUID) { } func (a *att) addLocalCharacteristic(startHandle uint16, properties CharacteristicPermissions, valueHandle uint16, uuid UUID, chr *Characteristic) { - a.characteristics = append(a.characteristics, + a.localCharacteristics = append(a.localCharacteristics, rawCharacteristic{ startHandle: startHandle, properties: uint8(properties), @@ -1128,11 +1175,29 @@ func (a *att) findAttribute(hdl uint16) *rawAttribute { } func (a *att) findCharacteristic(hdl uint16) *rawCharacteristic { - for i := range a.characteristics { - if a.characteristics[i].startHandle == hdl { - return &a.characteristics[i] + for i := range a.localCharacteristics { + if a.localCharacteristics[i].startHandle == hdl { + return &a.localCharacteristics[i] } } return nil } + +func (a *att) findConnectionData(handle uint16) (*connectData, error) { + cd, ok := a.connectionsData[handle] + if !ok { + return nil, ErrATTUnknownConnection + } + + return cd, nil +} + +func (a *att) lastError(handle uint16) (uint8, uint16, uint8) { + cd, err := a.findConnectionData(handle) + if err != nil { + return 0, 0, 0 + } + + return cd.lastErrorOpcode, cd.lastErrorHandle, cd.lastErrorCode +} diff --git a/gattc_hci.go b/gattc_hci.go index 48affc5..30bb251 100644 --- a/gattc_hci.go +++ b/gattc_hci.go @@ -59,6 +59,11 @@ func (d Device) DiscoverServices(uuids []UUID) ([]DeviceService, error) { services := make([]DeviceService, 0, maxDefaultServicesToDiscover) foundServices := make(map[UUID]DeviceService) + cd, err := d.adapter.att.findConnectionData(d.handle) + if err != nil { + return nil, err + } + startHandle := uint16(0x0001) endHandle := uint16(0xffff) for endHandle == uint16(0xffff) { @@ -68,14 +73,14 @@ func (d Device) DiscoverServices(uuids []UUID) ([]DeviceService, error) { } if debug { - println("found d.adapter.att.services", len(d.adapter.att.services)) + println("found services", len(cd.services)) } - if len(d.adapter.att.services) == 0 { + if len(cd.services) == 0 { break } - for _, rawService := range d.adapter.att.services { + for _, rawService := range cd.services { if len(uuids) == 0 || rawService.uuid.isIn(uuids) { foundServices[rawService.uuid] = DeviceService{ @@ -93,7 +98,7 @@ func (d Device) DiscoverServices(uuids []UUID) ([]DeviceService, error) { } // reset raw services - d.adapter.att.services = []rawService{} + cd.services = []rawService{} // did we find them all? if len(foundServices) == len(uuids) { @@ -155,30 +160,35 @@ func (s DeviceService) DiscoverCharacteristics(uuids []UUID) ([]DeviceCharacteri characteristics := make([]DeviceCharacteristic, 0, maxDefaultCharacteristicsToDiscover) foundCharacteristics := make(map[UUID]DeviceCharacteristic) + cd, err := s.device.adapter.att.findConnectionData(s.device.handle) + if err != nil { + return nil, err + } + startHandle := s.startHandle endHandle := s.endHandle for startHandle < endHandle { err := s.device.adapter.att.readByTypeReq(s.device.handle, startHandle, endHandle, gattCharacteristicUUID) switch { - case err == ErrATTOp && - s.device.adapter.att.lastErrorOpcode == attOpReadByTypeReq && - s.device.adapter.att.lastErrorCode == attErrorAttrNotFound: - - // no characteristics found - break + case err == ErrATTOp: + opcode, _, errcode := s.device.adapter.att.lastError(s.device.handle) + if opcode == attOpReadByTypeReq && errcode == attErrorAttrNotFound { + // no characteristics found + break + } case err != nil: return nil, err } if debug { - println("found s.device.adapter.att.characteristics", len(s.device.adapter.att.characteristics)) + println("found characteristics", len(cd.characteristics)) } - if len(s.device.adapter.att.characteristics) == 0 { + if len(cd.characteristics) == 0 { break } - for _, rawCharacteristic := range s.device.adapter.att.characteristics { + for _, rawCharacteristic := range cd.characteristics { if len(uuids) == 0 || rawCharacteristic.uuid.isIn(uuids) { dc := DeviceCharacteristic{ service: &s, @@ -195,7 +205,7 @@ func (s DeviceService) DiscoverCharacteristics(uuids []UUID) ([]DeviceCharacteri } // reset raw characteristics - s.device.adapter.att.characteristics = []rawCharacteristic{} + cd.characteristics = []rawCharacteristic{} // did we find them all? if len(foundCharacteristics) == len(uuids) { @@ -305,11 +315,16 @@ func (c DeviceCharacteristic) Read(data []byte) (int, error) { return 0, err } - if len(c.service.device.adapter.att.value) == 0 { + cd, err := c.service.device.adapter.att.findConnectionData(c.service.device.handle) + if err != nil { + return 0, err + } + + if len(cd.value) == 0 { return 0, errReadFailed } - copy(data, c.service.device.adapter.att.value) + copy(data, cd.value) - return len(c.service.device.adapter.att.value), nil + return len(cd.value), nil }